#!/usr/bin/env python # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Generic script to migrate any policy model with normalization layers to the new pipeline-based system. This script: 1. Loads an existing pretrained policy model 2. Extracts normalization statistics from the model 3. Creates both preprocessor and postprocessor: - Preprocessor: normalizes both inputs (observations) and outputs (actions) for training - Postprocessor: unnormalizes outputs (actions) for inference 4. Removes normalization layers from the model state_dict 5. Saves the new model and both processors Usage: python src/lerobot/processor/migrate_policy_normalization.py \ --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \ --policy-type act \ --push-to-hub """ import argparse import importlib import json import os from copy import deepcopy from pathlib import Path from typing import Any import torch from huggingface_hub import hf_hub_download from safetensors.torch import load_file as load_safetensors from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from .batch_processor import ToBatchProcessor from .device_processor import DeviceProcessor from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor from .pipeline import RobotProcessor from .rename_processor import RenameProcessor # Policy type to class mapping POLICY_CLASSES = { "act": "lerobot.policies.act.modeling_act.ACTPolicy", "diffusion": "lerobot.policies.diffusion.modeling_diffusion.DiffusionPolicy", "pi0": "lerobot.policies.pi0.modeling_pi0.PI0Policy", "pi0fast": "lerobot.policies.pi0fast.modeling_pi0fast.PI0FASTPolicy", "smolvla": "lerobot.policies.smolvla.modeling_smolvla.SmolVLAPolicy", "tdmpc": "lerobot.policies.tdmpc.modeling_tdmpc.TDMPCPolicy", "vqbet": "lerobot.policies.vqbet.modeling_vqbet.VQBeTPolicy", "sac": "lerobot.policies.sac.modeling_sac.SACPolicy", "classifier": "lerobot.policies.classifier.modeling_classifier.ClassifierPolicy", } def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]: """Extract normalization statistics from model state_dict.""" stats = {} # Define patterns to match and their prefixes to remove normalization_patterns = [ "normalize_inputs.buffer_", "unnormalize_outputs.buffer_", "normalize_targets.buffer_", "normalize.", # Must come after normalize_* patterns "unnormalize.", # Must come after unnormalize_* patterns "input_normalizer.", "output_normalizer.", ] # Process each key in state_dict for key, tensor in state_dict.items(): # Try each pattern for pattern in normalization_patterns: if key.startswith(pattern): # Extract the remaining part after the pattern remaining = key[len(pattern) :] parts = remaining.split(".") # Need at least feature name and stat type if len(parts) >= 2: # Last part is the stat type (mean, std, min, max, etc.) stat_type = parts[-1] # Everything else is the feature name feature_name = ".".join(parts[:-1]).replace("_", ".") # Add to stats if feature_name not in stats: stats[feature_name] = {} stats[feature_name][stat_type] = tensor.clone() # Only process the first matching pattern break return stats def detect_features_and_norm_modes( config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]] ) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]: """Detect features and normalization modes from config and stats.""" features = {} norm_modes = {} # First, check if there's a normalization_mapping in the config if "normalization_mapping" in config: print(f"Found normalization_mapping in config: {config['normalization_mapping']}") # Extract normalization modes from config for feature_name, mode_str in config["normalization_mapping"].items(): # Convert string to NormalizationMode enum if mode_str == "mean_std": mode = NormalizationMode.MEAN_STD elif mode_str == "min_max": mode = NormalizationMode.MIN_MAX else: print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'") continue # Determine feature type from feature name if "image" in feature_name or "visual" in feature_name: feature_type = FeatureType.VISUAL elif "state" in feature_name: feature_type = FeatureType.STATE elif "action" in feature_name: feature_type = FeatureType.ACTION else: feature_type = FeatureType.STATE norm_modes[feature_type] = mode # Try to extract from config if "features" in config: for key, feature_config in config["features"].items(): shape = feature_config.get("shape", feature_config.get("dim")) shape = (shape,) if isinstance(shape, int) else tuple(shape) # Determine feature type if "image" in key or "visual" in key: feature_type = FeatureType.VISUAL elif "state" in key: feature_type = FeatureType.STATE elif "action" in key: feature_type = FeatureType.ACTION else: feature_type = FeatureType.STATE # Default features[key] = PolicyFeature(feature_type, shape) # If no features in config, infer from stats if not features: for key, stat_dict in stats.items(): # Get shape from any stat tensor tensor = next(iter(stat_dict.values())) shape = tuple(tensor.shape) # Determine feature type based on key if "image" in key or "visual" in key or "pixels" in key: feature_type = FeatureType.VISUAL elif "state" in key or "joint" in key or "position" in key: feature_type = FeatureType.STATE elif "action" in key: feature_type = FeatureType.ACTION else: feature_type = FeatureType.STATE features[key] = PolicyFeature(feature_type, shape) # If normalization modes weren't in config, determine based on available stats if not norm_modes: for key, stat_dict in stats.items(): if key in features: if "mean" in stat_dict and "std" in stat_dict: feature_type = features[key].type if feature_type not in norm_modes: norm_modes[feature_type] = NormalizationMode.MEAN_STD elif "min" in stat_dict and "max" in stat_dict: feature_type = features[key].type if feature_type not in norm_modes: norm_modes[feature_type] = NormalizationMode.MIN_MAX # Default normalization modes if not detected if FeatureType.VISUAL not in norm_modes: norm_modes[FeatureType.VISUAL] = NormalizationMode.MEAN_STD if FeatureType.STATE not in norm_modes: norm_modes[FeatureType.STATE] = NormalizationMode.MIN_MAX if FeatureType.ACTION not in norm_modes: norm_modes[FeatureType.ACTION] = NormalizationMode.MEAN_STD return features, norm_modes def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Remove normalization layers from state_dict.""" new_state_dict = {} # Patterns to remove remove_patterns = [ "normalize_inputs.", "unnormalize_outputs.", "normalize_targets.", # Added pattern for target normalization "normalize.", "unnormalize.", "input_normalizer.", "output_normalizer.", "normalizer.", ] for key, tensor in state_dict.items(): should_remove = any(pattern in key for pattern in remove_patterns) if not should_remove: new_state_dict[key] = tensor return new_state_dict def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]: """Convert features from old format to PolicyFeature objects.""" converted_features = {} for key, feature_dict in features_dict.items(): # Determine feature type based on key if "image" in key or "visual" in key: feature_type = FeatureType.VISUAL elif "state" in key: feature_type = FeatureType.STATE elif "action" in key: feature_type = FeatureType.ACTION else: feature_type = FeatureType.STATE # Get shape from feature dict shape = feature_dict.get("shape", feature_dict.get("dim")) shape = (shape,) if isinstance(shape, int) else tuple(shape) converted_features[key] = PolicyFeature(feature_type, shape) return converted_features def load_model_from_hub( repo_id: str, revision: str = None ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: """Load model state_dict and config from hub.""" # Download files safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision) config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision) train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision) # Load state_dict state_dict = load_safetensors(safetensors_path) # Load config with open(config_path) as f: config = json.load(f) with open(train_config_path) as f: train_config = json.load(f) return state_dict, config, train_config def main(): parser = argparse.ArgumentParser( description="Migrate policy models with normalization layers to new pipeline system" ) parser.add_argument( "--pretrained-path", type=str, required=True, help="Path to pretrained model (hub repo or local directory)", ) parser.add_argument( "--output-dir", type=str, default=None, help="Output directory for migrated model (default: same as pretrained-path)", ) parser.add_argument("--push-to-hub", action="store_true", help="Push migrated model to hub") parser.add_argument( "--hub-repo-id", type=str, default=None, help="Hub repository ID for pushing (default: same as pretrained-path)", ) parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load") parser.add_argument("--private", action="store_true", help="Make the hub repository private") args = parser.parse_args() # Load model and config print(f"Loading model from {args.pretrained_path}...") if os.path.isdir(args.pretrained_path): # Local directory state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors")) with open(os.path.join(args.pretrained_path, "config.json")) as f: config = json.load(f) with open(os.path.join(args.pretrained_path, "train_config.json")) as f: train_config = json.load(f) else: # Hub repository state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision) # Extract normalization statistics print("Extracting normalization statistics...") stats = extract_normalization_stats(state_dict) print(f"Found normalization statistics for: {list(stats.keys())}") # Detect input features and normalization modes print("Detecting features and normalization modes...") features, norm_map = detect_features_and_norm_modes(config, stats) print(f"Detected features: {list(features.keys())}") print(f"Normalization modes: {norm_map}") # Remove normalization layers from state_dict print("Removing normalization layers from model...") new_state_dict = remove_normalization_layers(state_dict) removed_keys = set(state_dict.keys()) - set(new_state_dict.keys()) if removed_keys: print(f"Removed {len(removed_keys)} normalization layer keys") # Determine output path if args.output_dir: output_dir = Path(args.output_dir) else: if os.path.isdir(args.pretrained_path): output_dir = Path(args.pretrained_path).parent / f"{Path(args.pretrained_path).name}_migrated" else: output_dir = Path(f"./{args.pretrained_path.replace('/', '_')}_migrated") output_dir.mkdir(parents=True, exist_ok=True) # Clean up config - remove normalization_mapping field cleaned_config = dict(config) if "normalization_mapping" in cleaned_config: print("Removing 'normalization_mapping' field from config") del cleaned_config["normalization_mapping"] policy_type = deepcopy(cleaned_config["type"]) del cleaned_config["type"] # Instantiate the policy model with cleaned config and load the cleaned state dict print(f"Instantiating {policy_type} policy model...") policy_class_path = POLICY_CLASSES[policy_type] module_path, class_name = policy_class_path.rsplit(".", 1) module = importlib.import_module(module_path) policy_class = getattr(module, class_name) # Create config class instance config_module_path = module_path.replace("modeling", "configuration") config_module = importlib.import_module(config_module_path) # Handle special cases for config class names config_class_names = { "act": "ACTConfig", "diffusion": "DiffusionConfig", "pi0": "PI0Config", "pi0fast": "PI0FASTConfig", "smolvla": "SmolVLAConfig", "tdmpc": "TDMPCConfig", "vqbet": "VQBeTConfig", "sac": "SACConfig", "classifier": "ClassifierConfig", } config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config") config_class = getattr(config_module, config_class_name) # Convert input_features and output_features to PolicyFeature objects - these are mandatory if "input_features" not in cleaned_config: raise ValueError("Missing mandatory 'input_features' in config") if "output_features" not in cleaned_config: raise ValueError("Missing mandatory 'output_features' in config") cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"]) cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"]) # Create config instance from cleaned config dict policy_config = config_class(**cleaned_config) # Create policy instance - some policies expect dataset_stats policy = policy_class(policy_config) # Load the cleaned state dict policy.load_state_dict(new_state_dict, strict=True) print("Successfully loaded cleaned state dict into policy model") # Now create preprocessor and postprocessor with cleaned_config available print("Creating preprocessor and postprocessor...") # The pattern from existing processor factories: # - Preprocessor has two NormalizerProcessors: one for input_features, one for output_features # - Postprocessor has one UnnormalizerProcessor for output_features only # Get features from cleaned_config (now they're PolicyFeature objects) input_features = cleaned_config.get("input_features", {}) output_features = cleaned_config.get("output_features", {}) # Create preprocessor with two normalizers (following the pattern from processor factories) preprocessor_steps = [ RenameProcessor(rename_map={}), NormalizerProcessor( features={**input_features, **output_features}, norm_map=norm_map, stats=stats, ), ToBatchProcessor(), DeviceProcessor(device=policy_config.device), ] preprocessor = RobotProcessor(steps=preprocessor_steps, name="robot_preprocessor") # Create postprocessor with unnormalizer for outputs only postprocessor_steps = [ DeviceProcessor(device="cpu"), UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats), ] postprocessor = RobotProcessor(steps=postprocessor_steps, name="robot_postprocessor") # Determine hub repo ID if pushing to hub if args.push_to_hub: if args.hub_repo_id: hub_repo_id = args.hub_repo_id else: if not os.path.isdir(args.pretrained_path): # Use same repo with "_migrated" suffix hub_repo_id = f"{args.pretrained_path}_migrated" else: raise ValueError("--hub-repo-id must be specified when pushing local model to hub") else: hub_repo_id = None # Save preprocessor and postprocessor to root directory print(f"Saving preprocessor to {output_dir}...") preprocessor.save_pretrained(output_dir) if args.push_to_hub: preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private) print(f"Saving postprocessor to {output_dir}...") postprocessor.save_pretrained(output_dir) if args.push_to_hub: postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private) # Save model using the policy's save_pretrained method print(f"Saving model to {output_dir}...") policy.save_pretrained( output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private ) # Generate and save model card print("Generating model card...") # Get metadata from original config dataset_repo_id = train_config.get("repo_id", "unknown") license = config.get("license", "apache-2.0") tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type] tags = set(tags).union({"robotics", "lerobot", policy_type}) tags = list(tags) # Generate model card card = policy.generate_model_card( dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags ) # Save model card locally card.save(str(output_dir / "README.md")) print(f"Model card saved to {output_dir / 'README.md'}") # Push model card to hub if requested if args.push_to_hub: from huggingface_hub import HfApi api = HfApi() api.upload_file( path_or_fileobj=str(output_dir / "README.md"), path_in_repo="README.md", repo_id=hub_repo_id, repo_type="model", commit_message="Add model card for migrated model", ) print("Model card pushed to hub") print("\nMigration complete!") print(f"Migrated model saved to: {output_dir}") if args.push_to_hub: print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}") if __name__ == "__main__": main()