From 8abc9037a3f3faab28ca04fe204a0d823bd41bc1 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sat, 21 Feb 2026 07:41:42 +0100 Subject: [PATCH] sample 100k --- src/lerobot/scripts/lerobot_train.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 53274ed49..08d2a6ac8 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -246,31 +246,36 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # Recompute action stats as delta if use_delta_actions is enabled. # Must iterate the actual dataset (which returns action chunks via delta_timestamps) # so stats capture the full range of chunk-level deltas, not just per-frame deltas. + # We sample a subset for speed — 100K frames is sufficient for accurate stats. if getattr(cfg.policy, "use_delta_actions", False) and is_main_process: - logging.info("use_delta_actions is enabled — computing delta action stats from dataset chunks") + import numpy as np + from lerobot.datasets.compute_stats import get_feature_stats from lerobot.processor.delta_action_processor import to_delta_actions + max_samples = min(100_000, len(dataset)) + indices = np.random.choice(len(dataset), max_samples, replace=False) + logging.info( + f"use_delta_actions is enabled — computing delta action stats from {max_samples} dataset chunks" + ) + all_delta_actions = [] - for i in range(len(dataset)): - item = dataset[i] + for i in indices: + item = dataset[int(i)] action = item["action"] state = item["observation.state"] - # action may be (chunk_size, action_dim) or (action_dim,) if action.ndim == 1: action = action.unsqueeze(0) mask = [True] * action.shape[-1] delta = to_delta_actions(action.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0) all_delta_actions.append(delta.numpy()) - import numpy as np - all_delta = np.concatenate(all_delta_actions, axis=0) delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1) dataset.meta.stats["action"] = delta_stats logging.info( - f"Delta action stats computed from {len(dataset)} samples: " - f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}" + f"Delta action stats: mean={np.abs(delta_stats['mean']).mean():.4f}, " + f"std={delta_stats['std'].mean():.4f}" ) # Wait for all processes to finish policy creation before continuing