feat(processor): convert openpi model with processor

This commit is contained in:
AdilZouitine
2025-09-19 15:48:35 +02:00
parent d691d1e4fe
commit 10f5ea854f
8 changed files with 481 additions and 174 deletions

View File

@@ -302,6 +302,65 @@ def clean_state_dict(
return new_state_dict
def load_state_dict_with_missing_key_handling(
policy: torch.nn.Module,
state_dict: dict[str, torch.Tensor],
policy_type: str,
known_missing_keys_whitelist: dict[str, list[str]],
) -> list[str]:
"""
Load state dict into policy with graceful handling of missing keys.
This function loads the state dict with strict=False, filters out whitelisted
missing keys, and provides detailed reporting about any issues found.
Args:
policy: The policy model to load the state dict into.
state_dict: The cleaned state dictionary to load.
policy_type: The type of policy (used for whitelist lookup).
known_missing_keys_whitelist: Dictionary mapping policy types to lists of
known acceptable missing keys.
Returns:
List of problematic missing keys that weren't in the whitelist.
"""
# Load the cleaned state dict with strict=False to capture missing/unexpected keys
load_result = policy.load_state_dict(state_dict, strict=False)
# Check for missing keys
missing_keys = load_result.missing_keys
unexpected_keys = load_result.unexpected_keys
# Filter out whitelisted missing keys
policy_type_lower = policy_type.lower()
whitelisted_keys = known_missing_keys_whitelist.get(policy_type_lower, [])
problematic_missing_keys = [key for key in missing_keys if key not in whitelisted_keys]
if missing_keys:
if problematic_missing_keys:
print(f"⚠️ WARNING: Found {len(problematic_missing_keys)} unexpected missing keys:")
for key in problematic_missing_keys:
print(f" - {key}")
if len(missing_keys) > len(problematic_missing_keys):
whitelisted_missing = [key for key in missing_keys if key in whitelisted_keys]
print(f" INFO: Found {len(whitelisted_missing)} expected missing keys (whitelisted):")
for key in whitelisted_missing:
print(f" - {key}")
if unexpected_keys:
print(f"⚠️ WARNING: Found {len(unexpected_keys)} unexpected keys:")
for key in unexpected_keys:
print(f" - {key}")
if not missing_keys and not unexpected_keys:
print("✅ Successfully loaded cleaned state dict into policy model (all keys matched)")
else:
print("⚠️ State dict loaded with some missing/unexpected keys (see details above)")
return problematic_missing_keys
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
"""
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
@@ -335,9 +394,45 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
return converted_features
def display_migration_summary_with_warnings(problematic_missing_keys: list[str]) -> None:
"""
Display final migration summary with warnings about problematic missing keys.
Args:
problematic_missing_keys: List of missing keys that weren't in the whitelist.
"""
if not problematic_missing_keys:
return
print("\n" + "=" * 60)
print("🚨 IMPORTANT: MIGRATION COMPLETED WITH WARNINGS")
print("=" * 60)
print(
f"The migration was successful, but {len(problematic_missing_keys)} unexpected missing keys were found:"
)
print()
for key in problematic_missing_keys:
print(f"{key}")
print()
print("These missing keys may indicate:")
print(" • The model architecture has changed")
print(" • Some components were not properly saved in the original model")
print(" • The migration script needs to be updated for this policy type")
print()
print("What to do next:")
print(" 1. Test your migrated model carefully to ensure it works as expected")
print(" 2. If you encounter issues, please open an issue at:")
print(" https://github.com/huggingface/lerobot/issues")
print(" 3. Include this migration log and the missing keys listed above")
print()
print("If the model works correctly despite these warnings, the missing keys")
print("might be expected for your policy type and can be added to the whitelist.")
print("=" * 60)
def load_model_from_hub(
repo_id: str, revision: str | None = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any] | None]:
"""
Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
@@ -347,13 +442,12 @@ def load_model_from_hub(
Returns:
A tuple containing the model's state dictionary, the policy configuration,
and the training configuration.
and the training configuration (None if train_config.json is not found).
"""
# Download files.
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
# Load state_dict
state_dict = load_safetensors(safetensors_path)
@@ -362,8 +456,14 @@ def load_model_from_hub(
with open(config_path) as f:
config = json.load(f)
with open(train_config_path) as f:
train_config = json.load(f)
# Try to load train_config (optional)
train_config = None
try:
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
with open(train_config_path) as f:
train_config = json.load(f)
except FileNotFoundError:
print("train_config.json not found - continuing without training configuration")
return state_dict, config, train_config
@@ -409,8 +509,15 @@ def main():
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
with open(os.path.join(args.pretrained_path, "config.json")) as f:
config = json.load(f)
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
train_config = json.load(f)
# Try to load train_config (optional)
train_config = None
train_config_path = os.path.join(args.pretrained_path, "train_config.json")
if os.path.exists(train_config_path):
with open(train_config_path) as f:
train_config = json.load(f)
else:
print("train_config.json not found - continuing without training configuration")
else:
# Hub repository
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
@@ -487,10 +594,20 @@ def main():
policy_class = get_policy_class(policy_type)
policy = policy_class(policy_config)
# Load the cleaned state dict
policy.load_state_dict(new_state_dict, strict=True)
print("Successfully loaded cleaned state dict into policy model")
# Define whitelist of known missing keys that are acceptable (for example weight tie) for certain policy types
known_missing_keys_whitelist = {
"pi0": ["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"],
# Add other policy types and their known missing keys here as needed
}
# Load state dict with graceful missing key handling
problematic_missing_keys = load_state_dict_with_missing_key_handling(
policy=policy,
state_dict=new_state_dict,
policy_type=policy_type,
known_missing_keys_whitelist=known_missing_keys_whitelist,
)
policy.to(torch.float32)
# Create preprocessor and postprocessor using the factory
print("Creating preprocessor and postprocessor using make_pre_post_processors...")
preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats)
@@ -520,7 +637,9 @@ def main():
# Generate and save model card
print("Generating model card...")
# Get metadata from original config
dataset_repo_id = train_config.get("repo_id", "unknown")
dataset_repo_id = "unknown"
if train_config is not None:
dataset_repo_id = train_config.get("repo_id", "unknown")
license = config.get("license", "apache-2.0")
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
@@ -641,6 +760,9 @@ final_action = postprocessor(action)
else:
print(f"\nView the changes at: https://huggingface.co/{hub_repo_id}")
# Display final summary about any problematic missing keys
display_migration_summary_with_warnings(problematic_missing_keys)
if __name__ == "__main__":
main()