diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index 9ff71152a..1f507c75d 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -919,7 +919,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy): remap_count = 0 for key, value in fixed_state_dict.items(): - if not key.startswith("model."): + if not key.startswith("model.") and not any( + key.startswith(prefix) + for prefix in ["normalize_inputs.", "normalize_targets.", "unnormalize_outputs."] + ): new_key = f"model.{key}" remapped_state_dict[new_key] = value remap_count += 1 diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index 549dc0a9b..1fdb6048b 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -938,7 +938,10 @@ class PI0OpenPIPolicy(PreTrainedPolicy): remap_count = 0 for key, value in fixed_state_dict.items(): - if not key.startswith("model."): + if not key.startswith("model.") and not any( + key.startswith(prefix) + for prefix in ["normalize_inputs.", "normalize_targets.", "unnormalize_outputs."] + ): new_key = f"model.{key}" remapped_state_dict[new_key] = value remap_count += 1