pass stats

This commit is contained in:
Pepijn
2025-11-25 16:25:58 +01:00
parent 6b6a82bbdf
commit 3b31c2d9d3

View File

@@ -433,6 +433,9 @@ def make_policy(
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
if ds_meta is not None and hasattr(ds_meta, 'stats'):
kwargs["dataset_stats"] = ds_meta.stats
if cfg.pretrained_path:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time