diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index deb5a4681..a9c0ddb54 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -324,7 +324,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): temp_loader = torch.utils.data.DataLoader( stats_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0 ) - mean, std = compute_relative_action_stats(temp_loader, num_batches=1000) + reverse_rename = {v: k for k, v in cfg.rename_map.items()} if cfg.rename_map else {} + raw_state_key = reverse_rename.get("observation.state", "observation.state") + mean, std = compute_relative_action_stats(temp_loader, state_key=raw_state_key, num_batches=1000) del temp_loader, stats_dataset gc.collect() torch.save({"mean": mean, "std": std}, stats_path)