add model. prefix to all keys in state dict

This commit is contained in:
Pepijn
2025-09-10 20:35:19 +02:00
parent b3b57a8288
commit 2eafcc7ca1

View File

@@ -871,14 +871,109 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
self.reset()
@classmethod
def from_pretrained(cls, *args, **kwargs):
"""Override the from_pretrained method to display important disclaimer."""
def from_pretrained(
cls, *args, **kwargs
): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
" This implementation follows the original OpenPI structure for compatibility. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
return super().from_pretrained(*args, **kwargs)
# Store original strict mode
original_strict = kwargs.get("strict", True)
# Temporarily set strict=False to avoid loading issues, we'll handle it manually
kwargs["strict"] = False
# Call parent from_pretrained with strict=False
model = super().from_pretrained(*args, **kwargs)
# Extract the pretrained_model_name_or_path from args or kwargs for remapping
if len(args) > 0:
pretrained_model_name_or_path = args[0]
elif "pretrained_model_name_or_path" in kwargs:
pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"]
else:
return model
# Now manually load and remap the state dict
try:
from transformers.utils import cached_file
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_model_name_or_path}")
try:
# Try safetensors first
resolved_file = cached_file(
pretrained_model_name_or_path,
"model.safetensors",
cache_dir=kwargs.get("cache_dir"),
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
use_auth_token=kwargs.get("use_auth_token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
from safetensors.torch import load_file
original_state_dict = load_file(resolved_file)
print("✓ Loaded state dict from model.safetensors")
except Exception as e:
print(f"Could not load state dict from remote files: {e}")
return model
# Create a new state dict with "model." prefix for all keys that don't already have it
remapped_state_dict = {}
remap_count = 0
for key, value in original_state_dict.items():
if not key.startswith("model."):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
if remap_count <= 10: # Only print first 10 to avoid spam
print(f"Remapped: {key} -> {new_key}")
else:
remapped_state_dict[key] = value
if remap_count > 10:
print(f"... and {remap_count - 10} more keys remapped")
print(f"Total keys remapped: {remap_count}")
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict)
if missing_keys:
print(f"⚠️ Missing keys when loading state dict: {len(missing_keys)} keys")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"⚠️ Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("✅ All keys loaded successfully!")
except Exception as e:
print(f"⚠️ Warning: Could not remap state dict keys: {e}")
print("Using default loading behavior")
return model
def get_optim_params(self) -> dict: # see lerobot pi0 `get_optim_params`
return self.parameters()