diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 756d204f1..5b73b726c 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -95,7 +95,8 @@ def update_policy( start_time = time.perf_counter() device = get_device_from_parameters(policy) policy.train() - with torch.autocast(device_type=device.type) if use_amp and accelerator is None else nullcontext(): + # Let accelerator handle mixed precision + with accelerator.autocast() if accelerator else (torch.autocast(device_type=device.type) if use_amp else nullcontext()): loss, output_dict = policy.forward(batch) # TODO(rcadene): policy.unnormalize_outputs(out_dict) @@ -217,6 +218,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None): ) policy.to(device) + # Wait for all processes to finish policy creation before continuing + if accelerator: + accelerator.wait_for_everyone() + # Create processors - only provide dataset_stats if not resuming from saved processors processor_kwargs = {} postprocessor_kwargs = {} @@ -299,6 +304,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None): prefetch_factor=2 if cfg.num_workers > 0 else None, ) if accelerator: + accelerator.wait_for_everyone() policy, optimizer, dataloader, lr_scheduler = accelerator.prepare( policy, optimizer, dataloader, lr_scheduler ) @@ -380,8 +386,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None): if wandb_logger: wandb_logger.log_policy(checkpoint_dir) - if accelerator: - accelerator.wait_for_everyone() + if accelerator: + accelerator.wait_for_everyone() + if cfg.env and is_eval_step: step_id = get_step_identifier(step, cfg.steps) logging.info(f"Eval policy at step {step}") @@ -431,6 +438,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None): wandb_logger.log_dict(wandb_log_dict, step, mode="eval") wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval") + if accelerator: + accelerator.wait_for_everyone() + if eval_env: close_envs(eval_env)