Do not add model prefix to normalization

This commit is contained in:
Pepijn
2025-09-13 11:25:26 +02:00
parent c5a029a28a
commit 5361346bec
2 changed files with 8 additions and 2 deletions

View File

@@ -919,7 +919,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
remap_count = 0
for key, value in fixed_state_dict.items():
if not key.startswith("model."):
if not key.startswith("model.") and not any(
key.startswith(prefix)
for prefix in ["normalize_inputs.", "normalize_targets.", "unnormalize_outputs."]
):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1

View File

@@ -938,7 +938,10 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
remap_count = 0
for key, value in fixed_state_dict.items():
if not key.startswith("model."):
if not key.startswith("model.") and not any(
key.startswith(prefix)
for prefix in ["normalize_inputs.", "normalize_targets.", "unnormalize_outputs."]
):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1