mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 11:51:25 +00:00
feat(processor): convert openpi model with processor
This commit is contained in:
@@ -302,6 +302,65 @@ def clean_state_dict(
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def load_state_dict_with_missing_key_handling(
|
||||
policy: torch.nn.Module,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
policy_type: str,
|
||||
known_missing_keys_whitelist: dict[str, list[str]],
|
||||
) -> list[str]:
|
||||
"""
|
||||
Load state dict into policy with graceful handling of missing keys.
|
||||
|
||||
This function loads the state dict with strict=False, filters out whitelisted
|
||||
missing keys, and provides detailed reporting about any issues found.
|
||||
|
||||
Args:
|
||||
policy: The policy model to load the state dict into.
|
||||
state_dict: The cleaned state dictionary to load.
|
||||
policy_type: The type of policy (used for whitelist lookup).
|
||||
known_missing_keys_whitelist: Dictionary mapping policy types to lists of
|
||||
known acceptable missing keys.
|
||||
|
||||
Returns:
|
||||
List of problematic missing keys that weren't in the whitelist.
|
||||
"""
|
||||
# Load the cleaned state dict with strict=False to capture missing/unexpected keys
|
||||
load_result = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Check for missing keys
|
||||
missing_keys = load_result.missing_keys
|
||||
unexpected_keys = load_result.unexpected_keys
|
||||
|
||||
# Filter out whitelisted missing keys
|
||||
policy_type_lower = policy_type.lower()
|
||||
whitelisted_keys = known_missing_keys_whitelist.get(policy_type_lower, [])
|
||||
problematic_missing_keys = [key for key in missing_keys if key not in whitelisted_keys]
|
||||
|
||||
if missing_keys:
|
||||
if problematic_missing_keys:
|
||||
print(f"⚠️ WARNING: Found {len(problematic_missing_keys)} unexpected missing keys:")
|
||||
for key in problematic_missing_keys:
|
||||
print(f" - {key}")
|
||||
|
||||
if len(missing_keys) > len(problematic_missing_keys):
|
||||
whitelisted_missing = [key for key in missing_keys if key in whitelisted_keys]
|
||||
print(f"ℹ️ INFO: Found {len(whitelisted_missing)} expected missing keys (whitelisted):")
|
||||
for key in whitelisted_missing:
|
||||
print(f" - {key}")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"⚠️ WARNING: Found {len(unexpected_keys)} unexpected keys:")
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("✅ Successfully loaded cleaned state dict into policy model (all keys matched)")
|
||||
else:
|
||||
print("⚠️ State dict loaded with some missing/unexpected keys (see details above)")
|
||||
|
||||
return problematic_missing_keys
|
||||
|
||||
|
||||
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
|
||||
@@ -335,9 +394,45 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
|
||||
return converted_features
|
||||
|
||||
|
||||
def display_migration_summary_with_warnings(problematic_missing_keys: list[str]) -> None:
|
||||
"""
|
||||
Display final migration summary with warnings about problematic missing keys.
|
||||
|
||||
Args:
|
||||
problematic_missing_keys: List of missing keys that weren't in the whitelist.
|
||||
"""
|
||||
if not problematic_missing_keys:
|
||||
return
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🚨 IMPORTANT: MIGRATION COMPLETED WITH WARNINGS")
|
||||
print("=" * 60)
|
||||
print(
|
||||
f"The migration was successful, but {len(problematic_missing_keys)} unexpected missing keys were found:"
|
||||
)
|
||||
print()
|
||||
for key in problematic_missing_keys:
|
||||
print(f" ❌ {key}")
|
||||
print()
|
||||
print("These missing keys may indicate:")
|
||||
print(" • The model architecture has changed")
|
||||
print(" • Some components were not properly saved in the original model")
|
||||
print(" • The migration script needs to be updated for this policy type")
|
||||
print()
|
||||
print("What to do next:")
|
||||
print(" 1. Test your migrated model carefully to ensure it works as expected")
|
||||
print(" 2. If you encounter issues, please open an issue at:")
|
||||
print(" https://github.com/huggingface/lerobot/issues")
|
||||
print(" 3. Include this migration log and the missing keys listed above")
|
||||
print()
|
||||
print("If the model works correctly despite these warnings, the missing keys")
|
||||
print("might be expected for your policy type and can be added to the whitelist.")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def load_model_from_hub(
|
||||
repo_id: str, revision: str | None = None
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any] | None]:
|
||||
"""
|
||||
Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
|
||||
|
||||
@@ -347,13 +442,12 @@ def load_model_from_hub(
|
||||
|
||||
Returns:
|
||||
A tuple containing the model's state dictionary, the policy configuration,
|
||||
and the training configuration.
|
||||
and the training configuration (None if train_config.json is not found).
|
||||
"""
|
||||
# Download files.
|
||||
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
||||
|
||||
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
|
||||
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
|
||||
|
||||
# Load state_dict
|
||||
state_dict = load_safetensors(safetensors_path)
|
||||
@@ -362,8 +456,14 @@ def load_model_from_hub(
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
with open(train_config_path) as f:
|
||||
train_config = json.load(f)
|
||||
# Try to load train_config (optional)
|
||||
train_config = None
|
||||
try:
|
||||
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
|
||||
with open(train_config_path) as f:
|
||||
train_config = json.load(f)
|
||||
except FileNotFoundError:
|
||||
print("train_config.json not found - continuing without training configuration")
|
||||
|
||||
return state_dict, config, train_config
|
||||
|
||||
@@ -409,8 +509,15 @@ def main():
|
||||
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
|
||||
with open(os.path.join(args.pretrained_path, "config.json")) as f:
|
||||
config = json.load(f)
|
||||
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
|
||||
train_config = json.load(f)
|
||||
|
||||
# Try to load train_config (optional)
|
||||
train_config = None
|
||||
train_config_path = os.path.join(args.pretrained_path, "train_config.json")
|
||||
if os.path.exists(train_config_path):
|
||||
with open(train_config_path) as f:
|
||||
train_config = json.load(f)
|
||||
else:
|
||||
print("train_config.json not found - continuing without training configuration")
|
||||
else:
|
||||
# Hub repository
|
||||
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
|
||||
@@ -487,10 +594,20 @@ def main():
|
||||
policy_class = get_policy_class(policy_type)
|
||||
policy = policy_class(policy_config)
|
||||
|
||||
# Load the cleaned state dict
|
||||
policy.load_state_dict(new_state_dict, strict=True)
|
||||
print("Successfully loaded cleaned state dict into policy model")
|
||||
# Define whitelist of known missing keys that are acceptable (for example weight tie) for certain policy types
|
||||
known_missing_keys_whitelist = {
|
||||
"pi0": ["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"],
|
||||
# Add other policy types and their known missing keys here as needed
|
||||
}
|
||||
|
||||
# Load state dict with graceful missing key handling
|
||||
problematic_missing_keys = load_state_dict_with_missing_key_handling(
|
||||
policy=policy,
|
||||
state_dict=new_state_dict,
|
||||
policy_type=policy_type,
|
||||
known_missing_keys_whitelist=known_missing_keys_whitelist,
|
||||
)
|
||||
policy.to(torch.float32)
|
||||
# Create preprocessor and postprocessor using the factory
|
||||
print("Creating preprocessor and postprocessor using make_pre_post_processors...")
|
||||
preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats)
|
||||
@@ -520,7 +637,9 @@ def main():
|
||||
# Generate and save model card
|
||||
print("Generating model card...")
|
||||
# Get metadata from original config
|
||||
dataset_repo_id = train_config.get("repo_id", "unknown")
|
||||
dataset_repo_id = "unknown"
|
||||
if train_config is not None:
|
||||
dataset_repo_id = train_config.get("repo_id", "unknown")
|
||||
license = config.get("license", "apache-2.0")
|
||||
|
||||
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
|
||||
@@ -641,6 +760,9 @@ final_action = postprocessor(action)
|
||||
else:
|
||||
print(f"\nView the changes at: https://huggingface.co/{hub_repo_id}")
|
||||
|
||||
# Display final summary about any problematic missing keys
|
||||
display_migration_summary_with_warnings(problematic_missing_keys)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user