From fadb900c36db5600daca3d46dfee65b79c29934d Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sat, 21 Feb 2026 18:19:12 +0100 Subject: [PATCH] compute before dist --- src/lerobot/scripts/lerobot_train.py | 111 +++++++++++++-------------- 1 file changed, 53 insertions(+), 58 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 50e8a8b6b..890225b37 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -215,6 +215,59 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): logging.info("Creating dataset") dataset = make_dataset(cfg) + # Compute delta action stats BEFORE distributed sync to avoid NCCL timeout + if getattr(cfg.policy, "use_delta_actions", False): + import numpy as np + + from lerobot.datasets.compute_stats import get_feature_stats + from lerobot.processor.delta_action_processor import to_delta_actions + + chunk_size = cfg.policy.chunk_size + hf = dataset.hf_dataset + total_frames = len(hf) + max_samples = min(500_000, total_frames - chunk_size) + indices = np.random.choice(total_frames - chunk_size, max_samples, replace=False) + logging.info( + f"use_delta_actions is enabled — computing delta action stats " + f"from {max_samples} chunk samples (chunk_size={chunk_size})" + ) + + all_delta_actions = [] + episode_indices = np.array(hf["episode_index"]) + for idx in indices: + idx = int(idx) + ep_idx = episode_indices[idx] + end_idx = min(idx + chunk_size, total_frames) + if end_idx > idx and episode_indices[end_idx - 1] != ep_idx: + continue + + chunk_data = hf[idx:end_idx] + actions = torch.tensor(np.stack([np.asarray(a) for a in chunk_data["action"]])).float() + state = torch.tensor(np.asarray(chunk_data["observation.state"][0])).float() + + mask = [True] * actions.shape[-1] + delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0) + all_delta_actions.append(delta.numpy()) + + all_delta = np.concatenate(all_delta_actions, axis=0) + delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1) + dataset.meta.stats["action"] = delta_stats + + norm_type = "UNKNOWN" + if hasattr(cfg.policy, "normalization_mapping"): + from lerobot.configs.types import NormalizationMode + action_norm = cfg.policy.normalization_mapping.get("ACTION", None) + norm_type = action_norm.value if action_norm else "UNKNOWN" + + logging.info( + f"Delta action stats ({len(all_delta_actions)} chunks, {len(all_delta)} values, norm={norm_type}): " + f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}, " + f"q01={delta_stats['q01'].mean():.4f}, q99={delta_stats['q99'].mean():.4f}" + ) + if norm_type == "QUANTILES": + q_range = (delta_stats['q99'] - delta_stats['q01']).mean() + logging.info(f" Quantile range (q99-q01): {q_range:.4f}") + accelerator.wait_for_everyone() # Now all other processes can safely load the dataset @@ -243,64 +296,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): peft_cli_overrides = dataclasses.asdict(cfg.peft) policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides) - # Recompute action stats as delta if use_delta_actions is enabled. - # Must build action CHUNKS (like the model sees) and subtract state from each chunk. - # hf_dataset stores per-frame data; we manually assemble chunks to match delta_timestamps. - if getattr(cfg.policy, "use_delta_actions", False) and is_main_process: - import numpy as np - - from lerobot.datasets.compute_stats import get_feature_stats - from lerobot.processor.delta_action_processor import to_delta_actions - - chunk_size = cfg.policy.chunk_size - hf = dataset.hf_dataset - total_frames = len(hf) - max_samples = min(1_000_000, total_frames - chunk_size) - indices = np.random.choice(total_frames - chunk_size, max_samples, replace=False) - logging.info( - f"use_delta_actions is enabled — computing delta action stats " - f"from {max_samples} chunk samples (chunk_size={chunk_size})" - ) - - # Build chunks: for each index i, read actions[i:i+chunk_size] and state[i] - all_delta_actions = [] - episode_indices = np.array(hf["episode_index"]) - for idx in indices: - idx = int(idx) - # Ensure chunk doesn't cross episode boundary - ep_idx = episode_indices[idx] - end_idx = min(idx + chunk_size, total_frames) - if end_idx > idx and episode_indices[end_idx - 1] != ep_idx: - continue - - chunk_data = hf[idx:end_idx] - actions = torch.tensor(np.stack([np.asarray(a) for a in chunk_data["action"]])).float() - state = torch.tensor(np.asarray(chunk_data["observation.state"][0])).float() - - mask = [True] * actions.shape[-1] - delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0) - all_delta_actions.append(delta.numpy()) - - all_delta = np.concatenate(all_delta_actions, axis=0) - delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1) - dataset.meta.stats["action"] = delta_stats - - # Determine normalization type for logging - norm_type = "UNKNOWN" - if hasattr(cfg.policy, "normalization_mapping"): - from lerobot.configs.types import NormalizationMode - action_norm = cfg.policy.normalization_mapping.get("ACTION", None) - norm_type = action_norm.value if action_norm else "UNKNOWN" - - logging.info( - f"Delta action stats ({len(all_delta_actions)} chunks, {len(all_delta)} values, norm={norm_type}): " - f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}, " - f"q01={delta_stats['q01'].mean():.4f}, q99={delta_stats['q99'].mean():.4f}" - ) - if norm_type == "QUANTILES": - q_range = (delta_stats['q99'] - delta_stats['q01']).mean() - logging.info(f" Quantile range (q99-q01): {q_range:.4f}") - # Wait for all processes to finish policy creation before continuing accelerator.wait_for_everyone()