mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
add model. prefix to all keys in state dict
This commit is contained in:
@@ -871,14 +871,109 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
self.reset()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
"""Override the from_pretrained method to display important disclaimer."""
|
||||
def from_pretrained(
|
||||
cls, *args, **kwargs
|
||||
): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict
|
||||
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
|
||||
print(
|
||||
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
|
||||
" This implementation follows the original OpenPI structure for compatibility. \n"
|
||||
" Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||
)
|
||||
return super().from_pretrained(*args, **kwargs)
|
||||
|
||||
# Store original strict mode
|
||||
original_strict = kwargs.get("strict", True)
|
||||
# Temporarily set strict=False to avoid loading issues, we'll handle it manually
|
||||
kwargs["strict"] = False
|
||||
|
||||
# Call parent from_pretrained with strict=False
|
||||
model = super().from_pretrained(*args, **kwargs)
|
||||
|
||||
# Extract the pretrained_model_name_or_path from args or kwargs for remapping
|
||||
if len(args) > 0:
|
||||
pretrained_model_name_or_path = args[0]
|
||||
elif "pretrained_model_name_or_path" in kwargs:
|
||||
pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"]
|
||||
else:
|
||||
return model
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
try:
|
||||
from transformers.utils import cached_file
|
||||
|
||||
# Try to load the pytorch_model.bin or model.safetensors file
|
||||
print(f"Loading model from: {pretrained_model_name_or_path}")
|
||||
try:
|
||||
# Try safetensors first
|
||||
resolved_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
"model.safetensors",
|
||||
cache_dir=kwargs.get("cache_dir"),
|
||||
force_download=kwargs.get("force_download", False),
|
||||
resume_download=kwargs.get("resume_download"),
|
||||
proxies=kwargs.get("proxies"),
|
||||
use_auth_token=kwargs.get("use_auth_token"),
|
||||
revision=kwargs.get("revision"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
from safetensors.torch import load_file
|
||||
|
||||
original_state_dict = load_file(resolved_file)
|
||||
print("✓ Loaded state dict from model.safetensors")
|
||||
except Exception as e:
|
||||
print(f"Could not load state dict from remote files: {e}")
|
||||
return model
|
||||
|
||||
# Create a new state dict with "model." prefix for all keys that don't already have it
|
||||
remapped_state_dict = {}
|
||||
remap_count = 0
|
||||
|
||||
for key, value in original_state_dict.items():
|
||||
if not key.startswith("model."):
|
||||
new_key = f"model.{key}"
|
||||
remapped_state_dict[new_key] = value
|
||||
remap_count += 1
|
||||
if remap_count <= 10: # Only print first 10 to avoid spam
|
||||
print(f"Remapped: {key} -> {new_key}")
|
||||
else:
|
||||
remapped_state_dict[key] = value
|
||||
|
||||
if remap_count > 10:
|
||||
print(f"... and {remap_count - 10} more keys remapped")
|
||||
|
||||
print(f"Total keys remapped: {remap_count}")
|
||||
|
||||
# Load the remapped state dict into the model
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict)
|
||||
|
||||
if missing_keys:
|
||||
print(f"⚠️ Missing keys when loading state dict: {len(missing_keys)} keys")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in missing_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"⚠️ Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in unexpected_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("✅ All keys loaded successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Warning: Could not remap state dict keys: {e}")
|
||||
print("Using default loading behavior")
|
||||
|
||||
return model
|
||||
|
||||
def get_optim_params(self) -> dict: # see lerobot pi0 `get_optim_params`
|
||||
return self.parameters()
|
||||
|
||||
Reference in New Issue
Block a user