diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index 353549c52..baff3b15f 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -871,14 +871,109 @@ class PI0OpenPIPolicy(PreTrainedPolicy): self.reset() @classmethod - def from_pretrained(cls, *args, **kwargs): - """Override the from_pretrained method to display important disclaimer.""" + def from_pretrained( + cls, *args, **kwargs + ): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict + """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( "⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n" " This implementation follows the original OpenPI structure for compatibility. \n" " Original implementation: https://github.com/Physical-Intelligence/openpi" ) - return super().from_pretrained(*args, **kwargs) + + # Store original strict mode + original_strict = kwargs.get("strict", True) + # Temporarily set strict=False to avoid loading issues, we'll handle it manually + kwargs["strict"] = False + + # Call parent from_pretrained with strict=False + model = super().from_pretrained(*args, **kwargs) + + # Extract the pretrained_model_name_or_path from args or kwargs for remapping + if len(args) > 0: + pretrained_model_name_or_path = args[0] + elif "pretrained_model_name_or_path" in kwargs: + pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"] + else: + return model + + # Now manually load and remap the state dict + try: + from transformers.utils import cached_file + + # Try to load the pytorch_model.bin or model.safetensors file + print(f"Loading model from: {pretrained_model_name_or_path}") + try: + # Try safetensors first + resolved_file = cached_file( + pretrained_model_name_or_path, + "model.safetensors", + cache_dir=kwargs.get("cache_dir"), + force_download=kwargs.get("force_download", False), + resume_download=kwargs.get("resume_download"), + proxies=kwargs.get("proxies"), + use_auth_token=kwargs.get("use_auth_token"), + revision=kwargs.get("revision"), + local_files_only=kwargs.get("local_files_only", False), + ) + from safetensors.torch import load_file + + original_state_dict = load_file(resolved_file) + print("✓ Loaded state dict from model.safetensors") + except Exception as e: + print(f"Could not load state dict from remote files: {e}") + return model + + # Create a new state dict with "model." prefix for all keys that don't already have it + remapped_state_dict = {} + remap_count = 0 + + for key, value in original_state_dict.items(): + if not key.startswith("model."): + new_key = f"model.{key}" + remapped_state_dict[new_key] = value + remap_count += 1 + if remap_count <= 10: # Only print first 10 to avoid spam + print(f"Remapped: {key} -> {new_key}") + else: + remapped_state_dict[key] = value + + if remap_count > 10: + print(f"... and {remap_count - 10} more keys remapped") + + print(f"Total keys remapped: {remap_count}") + + # Load the remapped state dict into the model + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict) + + if missing_keys: + print(f"⚠️ Missing keys when loading state dict: {len(missing_keys)} keys") + if len(missing_keys) <= 5: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:5]: + print(f" - {key}") + print(f" ... and {len(missing_keys) - 5} more") + + if unexpected_keys: + print(f"⚠️ Unexpected keys when loading state dict: {len(unexpected_keys)} keys") + if len(unexpected_keys) <= 5: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:5]: + print(f" - {key}") + print(f" ... and {len(unexpected_keys) - 5} more") + + if not missing_keys and not unexpected_keys: + print("✅ All keys loaded successfully!") + + except Exception as e: + print(f"⚠️ Warning: Could not remap state dict keys: {e}") + print("Using default loading behavior") + + return model def get_optim_params(self) -> dict: # see lerobot pi0 `get_optim_params` return self.parameters()