diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 60a4d81d5..f2d07cd7f 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -63,6 +63,10 @@ class TrainPipelineConfig(HubMixin): scheduler: LRSchedulerConfig | None = None eval: EvalConfig = field(default_factory=EvalConfig) wandb: WandBConfig = field(default_factory=WandBConfig) + # Accelerate configuration for multi-GPU training + use_accelerate: bool = False + gradient_accumulation_steps: int = 1 + mixed_precision: str = "no" # Options: "no", "fp16", "bf16" def __post_init__(self): self.checkpoint_path = None diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 21da62bbb..aed53a53c 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -20,6 +20,7 @@ from pprint import pformat from typing import Any import torch +from accelerate import Accelerator from termcolor import colored from torch.amp import GradScaler from torch.optim import Optimizer @@ -64,6 +65,7 @@ def update_policy( lr_scheduler=None, use_amp: bool = False, lock=None, + accelerator: Accelerator = None, ) -> tuple[MetricsTracker, dict]: """ Performs a single training step to update the policy's weights. @@ -90,28 +92,48 @@ def update_policy( start_time = time.perf_counter() device = get_device_from_parameters(policy) policy.train() - with torch.autocast(device_type=device.type) if use_amp else nullcontext(): - loss, output_dict = policy.forward(batch) - # TODO(rcadene): policy.unnormalize_outputs(out_dict) - grad_scaler.scale(loss).backward() - # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**. - grad_scaler.unscale_(optimizer) + # Handle mixed precision differently for accelerate vs non-accelerate + if accelerator is not None: + # Accelerate handles mixed precision internally + with accelerator.autocast() if use_amp else nullcontext(): + loss, output_dict = policy.forward(batch) + # Use accelerator's backward method + accelerator.backward(loss) + else: + # Original behavior for non-accelerate + with torch.autocast(device_type=device.type) if use_amp else nullcontext(): + loss, output_dict = policy.forward(batch) + # TODO(rcadene): policy.unnormalize_outputs(out_dict) + grad_scaler.scale(loss).backward() - grad_norm = torch.nn.utils.clip_grad_norm_( - policy.parameters(), - grad_clip_norm, - error_if_nonfinite=False, - ) + if accelerator is not None: + # Accelerate handles gradient scaling internally + grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) + if grad_norm is None: + grad_norm = 0.0 + with lock if lock is not None else nullcontext(): + optimizer.step() + optimizer.zero_grad() + else: + # Original gradient handling for non-accelerate + # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**. + grad_scaler.unscale_(optimizer) - # Optimizer's gradients are already unscaled, so scaler.step does not unscale them, - # although it still skips optimizer.step() if the gradients contain infs or NaNs. - with lock if lock is not None else nullcontext(): - grad_scaler.step(optimizer) - # Updates the scale for next iteration. - grad_scaler.update() + grad_norm = torch.nn.utils.clip_grad_norm_( + policy.parameters(), + grad_clip_norm, + error_if_nonfinite=False, + ) - optimizer.zero_grad() + # Optimizer's gradients are already unscaled, so scaler.step does not unscale them, + # although it still skips optimizer.step() if the gradients contain infs or NaNs. + with lock if lock is not None else nullcontext(): + grad_scaler.step(optimizer) + # Updates the scale for next iteration. + grad_scaler.update() + + optimizer.zero_grad() # Step through pytorch scheduler at every batch instead of epoch if lr_scheduler is not None: @@ -147,6 +169,19 @@ def train(cfg: TrainPipelineConfig): cfg.validate() logging.info(pformat(cfg.to_dict())) + # Initialize Accelerate if requested + accelerator = None + if cfg.use_accelerate: + accelerator = Accelerator( + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + mixed_precision=cfg.mixed_precision, + ) + device = accelerator.device + logging.info(f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision}") + else: + # Check device is available (original behavior) + device = get_safe_torch_device(cfg.policy.device, log=True) + if cfg.wandb.enable and cfg.wandb.project: wandb_logger = WandBLogger(cfg) else: @@ -156,8 +191,6 @@ def train(cfg: TrainPipelineConfig): if cfg.seed is not None: set_seed(cfg.seed) - # Check device is available - device = get_safe_torch_device(cfg.policy.device, log=True) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -198,6 +231,13 @@ def train(cfg: TrainPipelineConfig): step = 0 # number of policy updates (forward + backward + optim) if cfg.resume: + if accelerator is not None: + # Load accelerate-specific state if available + accelerate_state_path = cfg.checkpoint_path / "accelerate_state" + if accelerate_state_path.exists(): + accelerator.load_state(str(accelerate_state_path)) + logging.info("Loaded Accelerate state from checkpoint") + step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) @@ -235,6 +275,14 @@ def train(cfg: TrainPipelineConfig): drop_last=False, prefetch_factor=2, ) + + # Prepare objects with Accelerate if enabled + if accelerator is not None: + policy, optimizer, dataloader, lr_scheduler = accelerator.prepare( + policy, optimizer, dataloader, lr_scheduler + ) + logging.info("Policy, optimizer, dataloader, and scheduler prepared with Accelerate") + dl_iter = cycle(dataloader) policy.train() @@ -253,21 +301,42 @@ def train(cfg: TrainPipelineConfig): logging.info("Start offline training on a fixed dataset") for _ in range(step, cfg.steps): - start_time = time.perf_counter() - batch = next(dl_iter) - batch = preprocessor(batch) - train_tracker.dataloading_s = time.perf_counter() - start_time + # Handle gradient accumulation + if accelerator is not None: + with accelerator.accumulate(policy): + start_time = time.perf_counter() + batch = next(dl_iter) + batch = preprocessor(batch) + train_tracker.dataloading_s = time.perf_counter() - start_time - train_tracker, output_dict = update_policy( - train_tracker, - policy, - batch, - optimizer, - cfg.optimizer.grad_clip_norm, - grad_scaler=grad_scaler, - lr_scheduler=lr_scheduler, - use_amp=cfg.policy.use_amp, - ) + train_tracker, output_dict = update_policy( + train_tracker, + policy, + batch, + optimizer, + cfg.optimizer.grad_clip_norm, + grad_scaler=grad_scaler, + lr_scheduler=lr_scheduler, + use_amp=cfg.policy.use_amp, + accelerator=accelerator, + ) + else: + start_time = time.perf_counter() + batch = next(dl_iter) + batch = preprocessor(batch) + train_tracker.dataloading_s = time.perf_counter() - start_time + + train_tracker, output_dict = update_policy( + train_tracker, + policy, + batch, + optimizer, + cfg.optimizer.grad_clip_norm, + grad_scaler=grad_scaler, + lr_scheduler=lr_scheduler, + use_amp=cfg.policy.use_amp, + accelerator=accelerator, + ) # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # increment `step` here. @@ -289,63 +358,101 @@ def train(cfg: TrainPipelineConfig): if cfg.save_checkpoint and is_saving_step: logging.info(f"Checkpoint policy after step {step}") checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) - save_checkpoint( - checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor - ) - update_last_checkpoint(checkpoint_dir) - if wandb_logger: - wandb_logger.log_policy(checkpoint_dir) + + if accelerator is not None: + # Use accelerate's checkpointing - only saves on main process + accelerator.wait_for_everyone() # Synchronize all processes + if accelerator.is_main_process: + # Use unwrapped model for saving + unwrapped_policy = accelerator.unwrap_model(policy) + save_checkpoint( + checkpoint_dir, + step, + cfg, + unwrapped_policy, + optimizer, + lr_scheduler, + preprocessor, + postprocessor, + ) + update_last_checkpoint(checkpoint_dir) + if wandb_logger: + wandb_logger.log_policy(checkpoint_dir) + # Save accelerate-specific state + accelerator.save_state(checkpoint_dir / "accelerate_state") + else: + # Original behavior for non-accelerate + save_checkpoint( + checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor + ) + update_last_checkpoint(checkpoint_dir) + if wandb_logger: + wandb_logger.log_policy(checkpoint_dir) if cfg.env and is_eval_step: - step_id = get_step_identifier(step, cfg.steps) - logging.info(f"Eval policy at step {step}") - with ( - torch.no_grad(), - torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), - ): - eval_info = eval_policy_all( - envs=eval_env, # dict[suite][task_id] -> vec_env - policy=policy, - preprocessor=preprocessor, - postprocessor=postprocessor, - n_episodes=cfg.eval.n_episodes, - videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", - max_episodes_rendered=4, - start_seed=cfg.seed, - max_parallel_tasks=cfg.env.max_parallel_tasks, + # Only evaluate on main process when using accelerate + if accelerator is None or accelerator.is_main_process: + step_id = get_step_identifier(step, cfg.steps) + logging.info(f"Eval policy at step {step}") + + # Use unwrapped model for evaluation if using accelerate + eval_policy = accelerator.unwrap_model(policy) if accelerator is not None else policy + + with ( + torch.no_grad(), + torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), + ): + eval_info = eval_policy_all( + envs=eval_env, # dict[suite][task_id] -> vec_env + policy=eval_policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + n_episodes=cfg.eval.n_episodes, + videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", + max_episodes_rendered=4, + start_seed=cfg.seed, + max_parallel_tasks=cfg.env.max_parallel_tasks, + ) + # overall metrics (suite-agnostic) + aggregated = eval_info["overall"] + + # optional: per-suite logging + for suite, suite_info in eval_info.items(): + logging.info("Suite %s aggregated: %s", suite, suite_info) + + # meters/tracker + eval_metrics = { + "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), + "pc_success": AverageMeter("success", ":.1f"), + "eval_s": AverageMeter("eval_s", ":.3f"), + } + eval_tracker = MetricsTracker( + cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step ) - # overall metrics (suite-agnostic) - aggregated = eval_info["overall"] + eval_tracker.eval_s = aggregated.pop("eval_s") + eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward") + eval_tracker.pc_success = aggregated.pop("pc_success") + if wandb_logger: + wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} + wandb_logger.log_dict(wandb_log_dict, step, mode="eval") + wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval") - # optional: per-suite logging - for suite, suite_info in eval_info.items(): - logging.info("Suite %s aggregated: %s", suite, suite_info) - - # meters/tracker - eval_metrics = { - "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), - "pc_success": AverageMeter("success", ":.1f"), - "eval_s": AverageMeter("eval_s", ":.3f"), - } - eval_tracker = MetricsTracker( - cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step - ) - eval_tracker.eval_s = aggregated.pop("eval_s") - eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward") - eval_tracker.pc_success = aggregated.pop("pc_success") - if wandb_logger: - wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} - wandb_logger.log_dict(wandb_log_dict, step, mode="eval") - wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval") + # Synchronize all processes after evaluation + if accelerator is not None: + accelerator.wait_for_everyone() if eval_env: close_envs(eval_env) logging.info("End of training") if cfg.policy.push_to_hub: - policy.push_model_to_hub(cfg) - preprocessor.push_to_hub(cfg.policy.repo_id) - postprocessor.push_to_hub(cfg.policy.repo_id) + # Only push to hub from main process when using accelerate + if accelerator is None or accelerator.is_main_process: + # Use unwrapped model for hub pushing if using accelerate + hub_policy = accelerator.unwrap_model(policy) if accelerator is not None else policy + hub_policy.push_model_to_hub(cfg) + preprocessor.push_to_hub(cfg.policy.repo_id) + postprocessor.push_to_hub(cfg.policy.repo_id) def main():