sample 100k

This commit is contained in:
Pepijn
2026-02-21 07:41:42 +01:00
parent e79b2a439b
commit 8abc9037a3

View File

@@ -246,31 +246,36 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Recompute action stats as delta if use_delta_actions is enabled.
# Must iterate the actual dataset (which returns action chunks via delta_timestamps)
# so stats capture the full range of chunk-level deltas, not just per-frame deltas.
# We sample a subset for speed — 100K frames is sufficient for accurate stats.
if getattr(cfg.policy, "use_delta_actions", False) and is_main_process:
logging.info("use_delta_actions is enabled — computing delta action stats from dataset chunks")
import numpy as np
from lerobot.datasets.compute_stats import get_feature_stats
from lerobot.processor.delta_action_processor import to_delta_actions
max_samples = min(100_000, len(dataset))
indices = np.random.choice(len(dataset), max_samples, replace=False)
logging.info(
f"use_delta_actions is enabled — computing delta action stats from {max_samples} dataset chunks"
)
all_delta_actions = []
for i in range(len(dataset)):
item = dataset[i]
for i in indices:
item = dataset[int(i)]
action = item["action"]
state = item["observation.state"]
# action may be (chunk_size, action_dim) or (action_dim,)
if action.ndim == 1:
action = action.unsqueeze(0)
mask = [True] * action.shape[-1]
delta = to_delta_actions(action.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
all_delta_actions.append(delta.numpy())
import numpy as np
all_delta = np.concatenate(all_delta_actions, axis=0)
delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
dataset.meta.stats["action"] = delta_stats
logging.info(
f"Delta action stats computed from {len(dataset)} samples: "
f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}"
f"Delta action stats: mean={np.abs(delta_stats['mean']).mean():.4f}, "
f"std={delta_stats['std'].mean():.4f}"
)
# Wait for all processes to finish policy creation before continuing