fix relative stats computation with rename_map

When rename_map maps a dataset key to observation.state, the raw
dataset used for stats computation still has the original key.
Reverse the rename_map to find the correct key.

Made-with: Cursor
This commit is contained in:
pepijn
2026-04-01 19:12:14 +00:00
parent 5a15a6a911
commit 900f1a42e9

View File

@@ -324,7 +324,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
temp_loader = torch.utils.data.DataLoader(
stats_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0
)
mean, std = compute_relative_action_stats(temp_loader, num_batches=1000)
reverse_rename = {v: k for k, v in cfg.rename_map.items()} if cfg.rename_map else {}
raw_state_key = reverse_rename.get("observation.state", "observation.state")
mean, std = compute_relative_action_stats(temp_loader, state_key=raw_state_key, num_batches=1000)
del temp_loader, stats_dataset
gc.collect()
torch.save({"mean": mean, "std": std}, stats_path)