diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index f58d13f4b..fb6633336 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -21,6 +21,7 @@ from typing import Any import torch from accelerate import Accelerator +from accelerate.utils import DistributedDataParallelKwargs from termcolor import colored from torch.amp import GradScaler from torch.optim import Optimizer @@ -176,9 +177,12 @@ def train(cfg: TrainPipelineConfig): # Initialize Accelerate if requested accelerator = None if cfg.use_accelerate: + # Configure DDP to handle unused parameters + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=cfg.gradient_accumulation_steps, mixed_precision=cfg.mixed_precision, + kwargs_handlers=[ddp_kwargs], ) device = accelerator.device if accelerator.is_main_process: