fix data loader issue

This commit is contained in:
Pepijn
2026-01-07 10:03:56 +01:00
parent 574081ac02
commit 63619619bf

View File

@@ -304,24 +304,29 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
)
# Compute per-timestep normalizer for relative actions
# Each process computes stats independently to avoid distributed sync issues
# Only main process computes, then broadcasts to avoid video decoder issues
relative_normalizer = None
if cfg.use_relative_actions:
mode = "actions + state" if cfg.use_relative_state else "actions only"
if is_main_process:
logging.info(colored(f"Relative mode: {mode}", "cyan", attrs=["bold"]))
logging.info("Computing per-timestep stats from dataset (first 1000 batches)...")
temp_loader = torch.utils.data.DataLoader(
dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0
)
mean, std = compute_relative_action_stats(temp_loader, num_batches=1000)
relative_normalizer = PerTimestepNormalizer(mean, std)
if is_main_process:
temp_loader = torch.utils.data.DataLoader(
dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0
)
mean, std = compute_relative_action_stats(temp_loader, num_batches=1000)
del temp_loader
cfg.output_dir.mkdir(parents=True, exist_ok=True)
relative_normalizer.save(cfg.output_dir / "relative_stats.pt")
logging.info(f"Saved stats to: {cfg.output_dir / 'relative_stats.pt'}")
stats_path = cfg.output_dir / "relative_stats.pt"
torch.save({"mean": mean, "std": std}, stats_path)
logging.info(f"Saved stats to: {stats_path}")
accelerator.wait_for_everyone()
# All ranks load from saved file
stats_path = cfg.output_dir / "relative_stats.pt"
data = torch.load(stats_path, weights_only=True, map_location="cpu")
relative_normalizer = PerTimestepNormalizer(data["mean"], data["std"])
step = 0 # number of policy updates (forward + backward + optim)