[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-09-11 11:51:53 +00:00
parent 565c992589
commit a19d7fb6bf
17 changed files with 469 additions and 254 deletions

View File

@@ -24,12 +24,15 @@ from accelerate.utils import set_seed as accelerate_set_seed
from termcolor import colored
from torch.optim import Optimizer
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.envs.factory import make_env
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.scripts.eval import eval_policy
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.train_utils import (
get_step_checkpoint_dir,
@@ -43,9 +46,6 @@ from lerobot.utils.utils import (
has_method,
init_logging,
)
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
def update_policy(
@@ -100,6 +100,7 @@ def train(cfg: TrainPipelineConfig):
# Initialize accelerator
from accelerate.utils import DistributedDataParallelKwargs
# added by jade 2 lines
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
accelerator = Accelerator(..., kwargs_handlers=[ddp_kwargs])
@@ -357,7 +358,7 @@ def train(cfg: TrainPipelineConfig):
if accelerator.is_main_process:
logging.info("End of training")
accelerator.end_training() # added by jade
accelerator.end_training() # added by jade
if __name__ == "__main__":