mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 11:51:25 +00:00
Add OpenPi, Pi0 and Pi0.5 (#1910)
* initial commit * change device in test * do detailed import * adhere to python 3.11 syntax * fix autodocstring * additionally * do same in other files * add model. prefix to all keys in state dict * use dummy stats * add pi05 * also shorten action_steps * fix test * all test pass! and fix tokenizer max length between 05 and 0 * remove test * fix transformer dependency * fix test * split pi0 and pi05 policy in seperate files * fix test * fix push to hub test * add some comments, license and readme * remove warning in config * add pi05 to factory * remove check * rename action_horizon to chunk_size * clean up padding of state and action (more in line with lerobot pi0) * add openpi image transforms for training and add more flexibility to _preprocess_images similar to lerobot pi0 * fix key match from pytorch state dict (similar keys to openpi implementation now) * also for pi05 * update to python 3.11 * revert to openpi transformer replace python 3.11 * fix(modeling pi0): nit warning message * use safeauto_docstring * fix: remove unused param * fix from pretrained * add preprocess tests * also compile forward method * Do not add model prefix to normalization * use same name for action and state dim as lerobot pi0 and remove fixed image keys * load from pretrained_path * temp: hardcode base model * fix override self.pretrained_path = None overwrite * rename to loss * remove additional image augmentations, lerobot dataset already does this * Add docs * put tests in test folder * Add test to instatiate all base models * go back to python 3.10 * update docs * adapt docs pi05 * change docs: finetune base model options * minor docs fixes and dependencies * remove todo * cast float64 to float32 for mps * skip if no transformers * fix tests * add new models to modelcard * add back init * fix circular input * feat: only run pi test on GPU * remove require_nightly_gpu * replace decorator test_pi0_openpi * rename action_dim, state_dim to max_action_dim, max_state_dim * fix doc and constants * cleanup tests * fix from pretrained * fix tests * add comment pi0 pi05 tests, add image features to pi0 pi05 hub tests * fix, state is included in language not in flow head * Move test to specific folder * and paligemma task with newline * remove add_special_tokens, not needed * feedback pr * Remove previous pi0 and rename pi0_openpi and pi05_openpi * Add Quantile stats to LeRobotDataset (#1985) * - Add RunningQuantileStats class for efficient histogram-based quantile computation - Integrate quantile parameters (compute_quantiles, quantiles) into LeRobotDataset - Support quantile computation during episode collection and aggregation - Add comprehensive function-based test suite (24 tests) for quantile functionality - Maintain full backward compatibility with existing stats computation - Enable configurable quantiles (default: [0.01, 0.99]) for robust normalization * style fixes, make quantiles computation by default to new datasets * fix tests * - Added DEFAULT_QUANTILES=[0.01, 0.10, 0.50, 0.90, 0.99] to be computed for each features instead of being chosen by the user - Fortified tests. * - add helper functions to reshape stats - add missing test for quantiles * - Add QUANTILE normalization mode to normalize the data with the 1st and 99th percentiles. - Add QUANTILE10 normalization mode to normalize the data with the 10th and 90th percentiles. * style fixes * Added missing lisence * Simplify compute_stats * - added script `augment_dataset_quantile_stats.py` so that we can add quantile stats to existing v3 datasets that dont have quatniles - modified quantile computation instead of using the edge for the value, interpolate the values in the bin * rename pi0/pi05 files * Remove open pi patch and use custom transformer branch for now * renaming * fix * Revert "fix" This reverts commit1ea65730ac. * fix naming * feet(pi0/pi0.5): add pipeline (#2009) * feat(processor): convert openpi model with processor * TODO: Make test works * fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests - Changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`. - Updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`. - Enhanced task handling in tests to ensure proper formatting and batch size consistency. - Cleaned up commented-out test code for clarity. * refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy - Updated imports and references throughout the codebase to reflect the new naming convention. - Introduced a new processor file for PI0 to handle pre-processing and post-processing steps. - Adjusted tests to utilize the renamed classes, ensuring consistency and functionality. - Enhanced clarity and maintainability by removing outdated naming conventions. * refactor(pi05): rename PI0OpenPIPolicy to PI0Policy and update configuration - Renamed `PI0OpenPIPolicy` to `PI0Policy` for consistency with naming conventions. - Updated the `PI05OpenPIConfig` to include a new `tokenizer_max_length` attribute and changed the normalization mode for state from `MEAN_STD` to `QUANTILES`. - Simplified model initialization in `PI05OpenPIPolicy` by removing unused `dataset_stats` parameter. - Added a new processor class for `Pi05PrepareStateTokenizerProcessorStep` with `@dataclass` for improved readability. - Introduced a test script to compare the integration of the PI0OpenPI policy with the original implementation, ensuring local testing compatibility. * feat(processor): convert openpi model with processor * TODO: Make test works * fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests - Changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`. - Updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`. - Enhanced task handling in tests to ensure proper formatting and batch size consistency. - Cleaned up commented-out test code for clarity. * refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy - Updated imports and references throughout the codebase to reflect the new naming convention. - Introduced a new processor file for PI0 to handle pre-processing and post-processing steps. - Adjusted tests to utilize the renamed classes, ensuring consistency and functionality. - Enhanced clarity and maintainability by removing outdated naming conventions. * refactor(pi05): rename PI0OpenPIPolicy to PI0Policy and update configuration - Renamed `PI0OpenPIPolicy` to `PI0Policy` for consistency with naming conventions. - Updated the `PI05OpenPIConfig` to include a new `tokenizer_max_length` attribute and changed the normalization mode for state from `MEAN_STD` to `QUANTILES`. - Simplified model initialization in `PI05OpenPIPolicy` by removing unused `dataset_stats` parameter. - Added a new processor class for `Pi05PrepareStateTokenizerProcessorStep` with `@dataclass` for improved readability. - Introduced a test script to compare the integration of the PI0OpenPI policy with the original implementation, ensuring local testing compatibility. * refactor(pi05): update imports and rename configuration classes - Changed imports to reflect the new naming convention for PI05 configuration and policy classes. - Renamed `PI05OpenPIConfig` to `PI05Config` and `PI05OpenPIPolicy` to `PI05Policy` for consistency. - Introduced a new processor file for PI05, implementing pre-processing and post-processing steps. - Updated tests to utilize the renamed classes, ensuring functionality and consistency across the codebase. * update(pi05): increase tokenizer_max_length for improved processing - Changed the `tokenizer_max_length` from 48 to 200 to enhance the model's capability in handling longer sequences. - This adjustment aims to improve the overall performance and flexibility of the PI05 configuration. * add default for state (max_state_dim) * correct naming * fix import * cleanup code * remove unused test * us quantiles for action * move to device * remove discrete state assert * fix pi05 test * move pi05 to device * use base models in comparison tests * small renames for tests * change number of tokens pi05 test * fix openpi tokenization in test * fix hub test * fix test * assert lerobot vs openpi tests --------- Co-authored-by: Pepijn <pepijn@huggingface.co> * add headers * add back previously removed imports * update if statement load processor with dataset stats * remove to avoid circular import * inject dataset stats for pretrained models * check normalization before applying * add link to quantile augument script * fix(policies): transformers import for ci in PI0 & PI05 (#2039) * fix(policies): transformers import for ci in PI0 * fix(policies): transformers import for ci in PI05 * test(processor): fix expected raise when normalization types are missing (#2040) * switch normalization order pipeline for pi05 * Fix/quantiles script (#2064) * refactor augment stats with quantiles script add parallelization for faster processing shift the quantile normalization between -1 1 * fix replay buffer tests * fix comment * overwrite the pipeline normalization features with the policy features * remove double normalization overwrite * cleanup from pretrained * remove typo * also set norm_map * fix(augment_quantiles) images incorrectly divided by 255 * clamp quantiles * link to lerobot base models * rename tests * encorperate PR feedback * update docstring for RunningQuantileStats * update doc links * Revert "clamp quantiles" This reverts commit172207471c. * fix self.paligemma * fix tests related to quantiles that were scaled to [0,1], the new range is [-1, 1] * fix libero doc and use different transformer branch * use fix branch instead of feat * update results libero * add new line * fix formatting * precommit * update results libero * update libero doc * update title * final changes * add quantiles to test * run pre commit --------- Signed-off-by: Steven Palma <imstevenpmwork@ieee.org> Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -303,6 +303,65 @@ def clean_state_dict(
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def load_state_dict_with_missing_key_handling(
|
||||
policy: torch.nn.Module,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
policy_type: str,
|
||||
known_missing_keys_whitelist: dict[str, list[str]],
|
||||
) -> list[str]:
|
||||
"""
|
||||
Load state dict into policy with graceful handling of missing keys.
|
||||
|
||||
This function loads the state dict with strict=False, filters out whitelisted
|
||||
missing keys, and provides detailed reporting about any issues found.
|
||||
|
||||
Args:
|
||||
policy: The policy model to load the state dict into.
|
||||
state_dict: The cleaned state dictionary to load.
|
||||
policy_type: The type of policy (used for whitelist lookup).
|
||||
known_missing_keys_whitelist: Dictionary mapping policy types to lists of
|
||||
known acceptable missing keys.
|
||||
|
||||
Returns:
|
||||
List of problematic missing keys that weren't in the whitelist.
|
||||
"""
|
||||
# Load the cleaned state dict with strict=False to capture missing/unexpected keys
|
||||
load_result = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Check for missing keys
|
||||
missing_keys = load_result.missing_keys
|
||||
unexpected_keys = load_result.unexpected_keys
|
||||
|
||||
# Filter out whitelisted missing keys
|
||||
policy_type_lower = policy_type.lower()
|
||||
whitelisted_keys = known_missing_keys_whitelist.get(policy_type_lower, [])
|
||||
problematic_missing_keys = [key for key in missing_keys if key not in whitelisted_keys]
|
||||
|
||||
if missing_keys:
|
||||
if problematic_missing_keys:
|
||||
print(f"WARNING: Found {len(problematic_missing_keys)} unexpected missing keys:")
|
||||
for key in problematic_missing_keys:
|
||||
print(f" - {key}")
|
||||
|
||||
if len(missing_keys) > len(problematic_missing_keys):
|
||||
whitelisted_missing = [key for key in missing_keys if key in whitelisted_keys]
|
||||
print(f"INFO: Found {len(whitelisted_missing)} expected missing keys (whitelisted):")
|
||||
for key in whitelisted_missing:
|
||||
print(f" - {key}")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"WARNING: Found {len(unexpected_keys)} unexpected keys:")
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("Successfully loaded cleaned state dict into policy model (all keys matched)")
|
||||
else:
|
||||
print("State dict loaded with some missing/unexpected keys (see details above)")
|
||||
|
||||
return problematic_missing_keys
|
||||
|
||||
|
||||
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
|
||||
@@ -336,9 +395,45 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
|
||||
return converted_features
|
||||
|
||||
|
||||
def display_migration_summary_with_warnings(problematic_missing_keys: list[str]) -> None:
|
||||
"""
|
||||
Display final migration summary with warnings about problematic missing keys.
|
||||
|
||||
Args:
|
||||
problematic_missing_keys: List of missing keys that weren't in the whitelist.
|
||||
"""
|
||||
if not problematic_missing_keys:
|
||||
return
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("IMPORTANT: MIGRATION COMPLETED WITH WARNINGS")
|
||||
print("=" * 60)
|
||||
print(
|
||||
f"The migration was successful, but {len(problematic_missing_keys)} unexpected missing keys were found:"
|
||||
)
|
||||
print()
|
||||
for key in problematic_missing_keys:
|
||||
print(f" - {key}")
|
||||
print()
|
||||
print("These missing keys may indicate:")
|
||||
print(" • The model architecture has changed")
|
||||
print(" • Some components were not properly saved in the original model")
|
||||
print(" • The migration script needs to be updated for this policy type")
|
||||
print()
|
||||
print("What to do next:")
|
||||
print(" 1. Test your migrated model carefully to ensure it works as expected")
|
||||
print(" 2. If you encounter issues, please open an issue at:")
|
||||
print(" https://github.com/huggingface/lerobot/issues")
|
||||
print(" 3. Include this migration log and the missing keys listed above")
|
||||
print()
|
||||
print("If the model works correctly despite these warnings, the missing keys")
|
||||
print("might be expected for your policy type and can be added to the whitelist.")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def load_model_from_hub(
|
||||
repo_id: str, revision: str | None = None
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any] | None]:
|
||||
"""
|
||||
Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
|
||||
|
||||
@@ -348,13 +443,12 @@ def load_model_from_hub(
|
||||
|
||||
Returns:
|
||||
A tuple containing the model's state dictionary, the policy configuration,
|
||||
and the training configuration.
|
||||
and the training configuration (None if train_config.json is not found).
|
||||
"""
|
||||
# Download files.
|
||||
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
||||
|
||||
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
|
||||
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
|
||||
|
||||
# Load state_dict
|
||||
state_dict = load_safetensors(safetensors_path)
|
||||
@@ -363,8 +457,14 @@ def load_model_from_hub(
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
with open(train_config_path) as f:
|
||||
train_config = json.load(f)
|
||||
# Try to load train_config (optional)
|
||||
train_config = None
|
||||
try:
|
||||
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
|
||||
with open(train_config_path) as f:
|
||||
train_config = json.load(f)
|
||||
except FileNotFoundError:
|
||||
print("train_config.json not found - continuing without training configuration")
|
||||
|
||||
return state_dict, config, train_config
|
||||
|
||||
@@ -410,8 +510,15 @@ def main():
|
||||
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
|
||||
with open(os.path.join(args.pretrained_path, "config.json")) as f:
|
||||
config = json.load(f)
|
||||
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
|
||||
train_config = json.load(f)
|
||||
|
||||
# Try to load train_config (optional)
|
||||
train_config = None
|
||||
train_config_path = os.path.join(args.pretrained_path, "train_config.json")
|
||||
if os.path.exists(train_config_path):
|
||||
with open(train_config_path) as f:
|
||||
train_config = json.load(f)
|
||||
else:
|
||||
print("train_config.json not found - continuing without training configuration")
|
||||
else:
|
||||
# Hub repository
|
||||
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
|
||||
@@ -488,10 +595,20 @@ def main():
|
||||
policy_class = get_policy_class(policy_type)
|
||||
policy = policy_class(policy_config)
|
||||
|
||||
# Load the cleaned state dict
|
||||
policy.load_state_dict(new_state_dict, strict=True)
|
||||
print("Successfully loaded cleaned state dict into policy model")
|
||||
# Define whitelist of known missing keys that are acceptable (for example weight tie) for certain policy types
|
||||
known_missing_keys_whitelist = {
|
||||
"pi0": ["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"],
|
||||
# Add other policy types and their known missing keys here as needed
|
||||
}
|
||||
|
||||
# Load state dict with graceful missing key handling
|
||||
problematic_missing_keys = load_state_dict_with_missing_key_handling(
|
||||
policy=policy,
|
||||
state_dict=new_state_dict,
|
||||
policy_type=policy_type,
|
||||
known_missing_keys_whitelist=known_missing_keys_whitelist,
|
||||
)
|
||||
policy.to(torch.float32)
|
||||
# Create preprocessor and postprocessor using the factory
|
||||
print("Creating preprocessor and postprocessor using make_pre_post_processors...")
|
||||
preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats)
|
||||
@@ -521,7 +638,9 @@ def main():
|
||||
# Generate and save model card
|
||||
print("Generating model card...")
|
||||
# Get metadata from original config
|
||||
dataset_repo_id = train_config.get("repo_id", "unknown")
|
||||
dataset_repo_id = "unknown"
|
||||
if train_config is not None:
|
||||
dataset_repo_id = train_config.get("repo_id", "unknown")
|
||||
license = config.get("license", "apache-2.0")
|
||||
|
||||
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
|
||||
@@ -552,25 +671,25 @@ def main():
|
||||
|
||||
if create_pr:
|
||||
# Separate commit description for PR body
|
||||
commit_description = """🤖 **Automated Policy Migration to PolicyProcessorPipeline**
|
||||
commit_description = """**Automated Policy Migration to PolicyProcessorPipeline**
|
||||
|
||||
This PR migrates your model to the new LeRobot policy format using the modern PolicyProcessorPipeline architecture.
|
||||
|
||||
## What Changed
|
||||
|
||||
### ✨ **New Architecture - PolicyProcessorPipeline**
|
||||
### **New Architecture - PolicyProcessorPipeline**
|
||||
Your model now uses external PolicyProcessorPipeline components for data processing instead of built-in normalization layers. This provides:
|
||||
- **Modularity**: Separate preprocessing and postprocessing pipelines
|
||||
- **Flexibility**: Easy to swap, configure, and debug processing steps
|
||||
- **Compatibility**: Works with the latest LeRobot ecosystem
|
||||
|
||||
### 🔧 **Normalization Extraction**
|
||||
### **Normalization Extraction**
|
||||
We've extracted normalization statistics from your model's state_dict and removed the built-in normalization layers:
|
||||
- **Extracted patterns**: `normalize_inputs.*`, `unnormalize_outputs.*`, `normalize.*`, `unnormalize.*`, `input_normalizer.*`, `output_normalizer.*`
|
||||
- **Statistics preserved**: Mean, std, min, max values for all features
|
||||
- **Clean model**: State dict now contains only core model weights
|
||||
|
||||
### 📦 **Files Added**
|
||||
### **Files Added**
|
||||
- **preprocessor_config.json**: Configuration for input preprocessing pipeline
|
||||
- **postprocessor_config.json**: Configuration for output postprocessing pipeline
|
||||
- **model.safetensors**: Clean model weights without normalization layers
|
||||
@@ -578,13 +697,13 @@ We've extracted normalization statistics from your model's state_dict and remove
|
||||
- **train_config.json**: Training configuration
|
||||
- **README.md**: Updated model card with migration information
|
||||
|
||||
### 🚀 **Benefits**
|
||||
### **Benefits**
|
||||
- **Backward Compatible**: Your model behavior remains identical
|
||||
- **Future Ready**: Compatible with latest LeRobot features and updates
|
||||
- **Debuggable**: Easy to inspect and modify processing steps
|
||||
- **Portable**: Processors can be shared and reused across models
|
||||
|
||||
### 💻 **Usage**
|
||||
### **Usage**
|
||||
```python
|
||||
# Load your migrated model
|
||||
from lerobot.policies import get_policy_class
|
||||
@@ -642,6 +761,9 @@ final_action = postprocessor(action)
|
||||
else:
|
||||
print(f"\nView the changes at: https://huggingface.co/{hub_repo_id}")
|
||||
|
||||
# Display final summary about any problematic missing keys
|
||||
display_migration_summary_with_warnings(problematic_missing_keys)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user