mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
fix relative stats computation with rename_map
When rename_map maps a dataset key to observation.state, the raw dataset used for stats computation still has the original key. Reverse the rename_map to find the correct key. Made-with: Cursor
This commit is contained in:
@@ -324,7 +324,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
temp_loader = torch.utils.data.DataLoader(
|
||||
stats_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0
|
||||
)
|
||||
mean, std = compute_relative_action_stats(temp_loader, num_batches=1000)
|
||||
reverse_rename = {v: k for k, v in cfg.rename_map.items()} if cfg.rename_map else {}
|
||||
raw_state_key = reverse_rename.get("observation.state", "observation.state")
|
||||
mean, std = compute_relative_action_stats(temp_loader, state_key=raw_state_key, num_batches=1000)
|
||||
del temp_loader, stats_dataset
|
||||
gc.collect()
|
||||
torch.save({"mean": mean, "std": std}, stats_path)
|
||||
|
||||
Reference in New Issue
Block a user