diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index f88d5c02f..405d619c8 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -304,24 +304,29 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): ) # Compute per-timestep normalizer for relative actions - # Each process computes stats independently to avoid distributed sync issues + # Only main process computes, then broadcasts to avoid video decoder issues relative_normalizer = None if cfg.use_relative_actions: mode = "actions + state" if cfg.use_relative_state else "actions only" if is_main_process: logging.info(colored(f"Relative mode: {mode}", "cyan", attrs=["bold"])) logging.info("Computing per-timestep stats from dataset (first 1000 batches)...") - - temp_loader = torch.utils.data.DataLoader( - dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0 - ) - mean, std = compute_relative_action_stats(temp_loader, num_batches=1000) - relative_normalizer = PerTimestepNormalizer(mean, std) - - if is_main_process: + temp_loader = torch.utils.data.DataLoader( + dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0 + ) + mean, std = compute_relative_action_stats(temp_loader, num_batches=1000) + del temp_loader cfg.output_dir.mkdir(parents=True, exist_ok=True) - relative_normalizer.save(cfg.output_dir / "relative_stats.pt") - logging.info(f"Saved stats to: {cfg.output_dir / 'relative_stats.pt'}") + stats_path = cfg.output_dir / "relative_stats.pt" + torch.save({"mean": mean, "std": std}, stats_path) + logging.info(f"Saved stats to: {stats_path}") + + accelerator.wait_for_everyone() + + # All ranks load from saved file + stats_path = cfg.output_dir / "relative_stats.pt" + data = torch.load(stats_path, weights_only=True, map_location="cpu") + relative_normalizer = PerTimestepNormalizer(data["mean"], data["std"]) step = 0 # number of policy updates (forward + backward + optim)