mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -24,12 +24,15 @@ from accelerate.utils import set_seed as accelerate_set_seed
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
@@ -43,9 +46,6 @@ from lerobot.utils.utils import (
|
||||
has_method,
|
||||
init_logging,
|
||||
)
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def update_policy(
|
||||
@@ -100,6 +100,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
# Initialize accelerator
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
# added by jade 2 lines
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
|
||||
accelerator = Accelerator(..., kwargs_handlers=[ddp_kwargs])
|
||||
@@ -357,7 +358,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if accelerator.is_main_process:
|
||||
logging.info("End of training")
|
||||
accelerator.end_training() # added by jade
|
||||
accelerator.end_training() # added by jade
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user