mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
Compare commits
3 Commits
docs/compl
...
docs/model
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
99c0d93b34 | ||
|
|
c62784e14c | ||
|
|
cc6a2cac43 |
@@ -1,172 +0,0 @@
|
|||||||
- sections:
|
|
||||||
- local: index
|
|
||||||
title: LeRobot
|
|
||||||
- local: installation
|
|
||||||
title: Installation
|
|
||||||
- local: cheat-sheet
|
|
||||||
title: Cheat sheet
|
|
||||||
title: Get started
|
|
||||||
- sections:
|
|
||||||
- local: il_robots
|
|
||||||
title: Imitation Learning for Robots
|
|
||||||
- local: bring_your_own_policies
|
|
||||||
title: Adding a Policy
|
|
||||||
- local: integrate_hardware
|
|
||||||
title: Bring Your Own Hardware
|
|
||||||
- local: hilserl
|
|
||||||
title: Train a Robot with RL
|
|
||||||
- local: hilserl_sim
|
|
||||||
title: Train RL in Simulation
|
|
||||||
- local: multi_gpu_training
|
|
||||||
title: Multi GPU training
|
|
||||||
- local: hil_data_collection
|
|
||||||
title: Human In the Loop Data Collection
|
|
||||||
- local: peft_training
|
|
||||||
title: Training with PEFT (e.g., LoRA)
|
|
||||||
- local: rename_map
|
|
||||||
title: Using Rename Map and Empty Cameras
|
|
||||||
title: "Tutorials"
|
|
||||||
- sections:
|
|
||||||
- local: hardware_guide
|
|
||||||
title: Compute Hardware Guide
|
|
||||||
- local: torch_accelerators
|
|
||||||
title: PyTorch accelerators
|
|
||||||
title: "Compute & Hardware"
|
|
||||||
- sections:
|
|
||||||
- local: lerobot-dataset-v3
|
|
||||||
title: Using LeRobotDataset
|
|
||||||
- local: porting_datasets_v3
|
|
||||||
title: Porting Large Datasets
|
|
||||||
- local: using_dataset_tools
|
|
||||||
title: Using the Dataset Tools
|
|
||||||
- local: language_and_recipes
|
|
||||||
title: Language Columns and Recipes
|
|
||||||
- local: tools
|
|
||||||
title: Tools
|
|
||||||
- local: video_encoding_parameters
|
|
||||||
title: Video encoding parameters
|
|
||||||
- local: streaming_video_encoding
|
|
||||||
title: Streaming Video Encoding
|
|
||||||
title: "Datasets"
|
|
||||||
- sections:
|
|
||||||
- local: act
|
|
||||||
title: ACT
|
|
||||||
- local: smolvla
|
|
||||||
title: SmolVLA
|
|
||||||
- local: pi0
|
|
||||||
title: π₀ (Pi0)
|
|
||||||
- local: pi0fast
|
|
||||||
title: π₀-FAST (Pi0Fast)
|
|
||||||
- local: pi05
|
|
||||||
title: π₀.₅ (Pi05)
|
|
||||||
- local: eo1
|
|
||||||
title: EO-1
|
|
||||||
- local: groot
|
|
||||||
title: NVIDIA GR00T N1.5
|
|
||||||
- local: xvla
|
|
||||||
title: X-VLA
|
|
||||||
- local: multi_task_dit
|
|
||||||
title: Multitask DiT Policy
|
|
||||||
- local: walloss
|
|
||||||
title: WALL-OSS
|
|
||||||
title: "Policies"
|
|
||||||
- sections:
|
|
||||||
- local: sarm
|
|
||||||
title: SARM
|
|
||||||
title: "Reward Models"
|
|
||||||
- sections:
|
|
||||||
- local: inference
|
|
||||||
title: Policy Deployment (lerobot-rollout)
|
|
||||||
- local: async
|
|
||||||
title: Use Async Inference
|
|
||||||
- local: rtc
|
|
||||||
title: Real-Time Chunking (RTC)
|
|
||||||
title: "Inference"
|
|
||||||
- sections:
|
|
||||||
- local: envhub
|
|
||||||
title: Environments from the Hub
|
|
||||||
- local: envhub_leisaac
|
|
||||||
title: Control & Train Robots in Sim (LeIsaac)
|
|
||||||
title: "Simulation"
|
|
||||||
- sections:
|
|
||||||
- local: adding_benchmarks
|
|
||||||
title: Adding a New Benchmark
|
|
||||||
- local: libero
|
|
||||||
title: LIBERO
|
|
||||||
- local: libero_plus
|
|
||||||
title: LIBERO-plus
|
|
||||||
- local: metaworld
|
|
||||||
title: Meta-World
|
|
||||||
- local: robotwin
|
|
||||||
title: RoboTwin 2.0
|
|
||||||
- local: robocasa
|
|
||||||
title: RoboCasa365
|
|
||||||
- local: robocerebra
|
|
||||||
title: RoboCerebra
|
|
||||||
- local: robomme
|
|
||||||
title: RoboMME
|
|
||||||
- local: envhub_isaaclab_arena
|
|
||||||
title: NVIDIA IsaacLab Arena Environments
|
|
||||||
- local: vlabench
|
|
||||||
title: VLABench
|
|
||||||
title: "Benchmarks"
|
|
||||||
- sections:
|
|
||||||
- local: introduction_processors
|
|
||||||
title: Introduction to Robot Processors
|
|
||||||
- local: debug_processor_pipeline
|
|
||||||
title: Debug your processor pipeline
|
|
||||||
- local: implement_your_own_processor
|
|
||||||
title: Implement your own processor
|
|
||||||
- local: processors_robots_teleop
|
|
||||||
title: Processors for Robots and Teleoperators
|
|
||||||
- local: env_processor
|
|
||||||
title: Environment Processors
|
|
||||||
- local: action_representations
|
|
||||||
title: Action Representations
|
|
||||||
title: "Robot Processors"
|
|
||||||
- sections:
|
|
||||||
- local: so101
|
|
||||||
title: SO-101
|
|
||||||
- local: so100
|
|
||||||
title: SO-100
|
|
||||||
- local: koch
|
|
||||||
title: Koch v1.1
|
|
||||||
- local: lekiwi
|
|
||||||
title: LeKiwi
|
|
||||||
- local: hope_jr
|
|
||||||
title: Hope Jr
|
|
||||||
- local: reachy2
|
|
||||||
title: Reachy 2
|
|
||||||
- local: unitree_g1
|
|
||||||
title: Unitree G1
|
|
||||||
- local: earthrover_mini_plus
|
|
||||||
title: Earth Rover Mini
|
|
||||||
- local: omx
|
|
||||||
title: OMX
|
|
||||||
- local: openarm
|
|
||||||
title: OpenArm
|
|
||||||
- local: rebot_b601
|
|
||||||
title: reBot B601-DM
|
|
||||||
title: "Robots"
|
|
||||||
- sections:
|
|
||||||
- local: phone_teleop
|
|
||||||
title: Phone
|
|
||||||
title: "Teleoperators"
|
|
||||||
- sections:
|
|
||||||
- local: cameras
|
|
||||||
title: Cameras
|
|
||||||
title: "Sensors"
|
|
||||||
- sections:
|
|
||||||
- local: notebooks
|
|
||||||
title: Notebooks
|
|
||||||
- local: feetech
|
|
||||||
title: Updating Feetech Firmware
|
|
||||||
- local: damiao
|
|
||||||
title: Damiao Motors and CAN Bus
|
|
||||||
title: "Resources"
|
|
||||||
- sections:
|
|
||||||
- local: contributing
|
|
||||||
title: Contribute to LeRobot
|
|
||||||
- local: backwardcomp
|
|
||||||
title: Backward compatibility
|
|
||||||
title: "About"
|
|
||||||
@@ -1,214 +1,172 @@
|
|||||||
# LeRobot documentation table of contents
|
|
||||||
#
|
|
||||||
# Ordering principle: gentle onboarding first, advanced/custom work last.
|
|
||||||
# Within each top-level section the same rule applies — concept/overview pages
|
|
||||||
# before reference/per-item pages.
|
|
||||||
#
|
|
||||||
# Pages marked "NEW (to create)" do not yet exist as .mdx files; they are
|
|
||||||
# placeholders for the redesign and must be authored before the docs build.
|
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: index
|
- local: index
|
||||||
title: 🤗 LeRobot
|
title: LeRobot
|
||||||
- local: quickstart # NEW (to create) — 15-min zero-to-trained-ACT path
|
|
||||||
title: Quickstart
|
|
||||||
- local: installation
|
- local: installation
|
||||||
title: Installation
|
title: Installation
|
||||||
- local: core_concepts # NEW (to create) — datasets, policies, processors, robots, envs in one mental model
|
|
||||||
title: Core concepts
|
|
||||||
- local: cheat-sheet
|
- local: cheat-sheet
|
||||||
title: Command cheat sheet
|
title: Cheat sheet
|
||||||
title: Get started
|
title: Get started
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: il_robots
|
- local: il_robots
|
||||||
title: Imitation learning end-to-end
|
title: Imitation Learning for Robots
|
||||||
|
- local: bring_your_own_policies
|
||||||
|
title: Adding a Policy
|
||||||
|
- local: integrate_hardware
|
||||||
|
title: Bring Your Own Hardware
|
||||||
|
- local: hilserl
|
||||||
|
title: Train a Robot with RL
|
||||||
|
- local: hilserl_sim
|
||||||
|
title: Train RL in Simulation
|
||||||
|
- local: multi_gpu_training
|
||||||
|
title: Multi GPU training
|
||||||
- local: hil_data_collection
|
- local: hil_data_collection
|
||||||
title: Human-in-the-loop data collection
|
title: Human In the Loop Data Collection
|
||||||
- local: inference
|
- local: peft_training
|
||||||
title: Deploying a trained policy
|
title: Training with PEFT (e.g., LoRA)
|
||||||
- local: rename_map
|
- local: rename_map
|
||||||
title: Matching dataset keys to a policy (rename map)
|
title: Using Rename Map and Empty Cameras
|
||||||
title: Your first project
|
title: "Tutorials"
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: hardware_guide
|
- local: hardware_guide
|
||||||
title: Compute hardware guide
|
title: Compute Hardware Guide
|
||||||
- local: torch_accelerators
|
- local: torch_accelerators
|
||||||
title: PyTorch accelerators
|
title: PyTorch accelerators
|
||||||
- local: multi_gpu_training
|
title: "Compute & Hardware"
|
||||||
title: Multi-GPU training
|
|
||||||
- local: peft_training
|
|
||||||
title: Parameter-efficient fine-tuning (LoRA)
|
|
||||||
title: Training
|
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: lerobot-dataset-v3
|
- local: lerobot-dataset-v3
|
||||||
title: Using LeRobotDataset
|
title: Using LeRobotDataset
|
||||||
|
- local: porting_datasets_v3
|
||||||
|
title: Porting Large Datasets
|
||||||
- local: using_dataset_tools
|
- local: using_dataset_tools
|
||||||
title: Dataset tools
|
title: Using the Dataset Tools
|
||||||
- local: language_and_recipes
|
- local: language_and_recipes
|
||||||
title: Language columns & recipes
|
title: Language Columns and Recipes
|
||||||
- local: tools
|
- local: tools
|
||||||
title: Tool calls in datasets
|
title: Tools
|
||||||
- local: video_encoding_parameters
|
- local: video_encoding_parameters
|
||||||
title: Video encoding parameters
|
title: Video encoding parameters
|
||||||
- local: streaming_video_encoding
|
- local: streaming_video_encoding
|
||||||
title: Streaming video encoding
|
title: Streaming Video Encoding
|
||||||
- local: porting_datasets_v3
|
title: "Datasets"
|
||||||
title: Porting datasets to v3
|
|
||||||
title: Datasets
|
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: policies_overview # NEW (to create) — concept hub + "choose a policy" decision guide
|
- local: act
|
||||||
title: Choosing a policy
|
title: ACT
|
||||||
- sections:
|
- local: smolvla
|
||||||
- local: act
|
title: SmolVLA
|
||||||
title: ACT
|
- local: pi0
|
||||||
- local: smolvla
|
title: π₀ (Pi0)
|
||||||
title: SmolVLA
|
- local: pi0fast
|
||||||
- local: pi0
|
title: π₀-FAST (Pi0Fast)
|
||||||
title: π₀ (Pi0)
|
- local: pi05
|
||||||
- local: pi0fast
|
title: π₀.₅ (Pi05)
|
||||||
title: π₀-FAST
|
- local: eo1
|
||||||
- local: pi05
|
title: EO-1
|
||||||
title: π₀.₅ (Pi05)
|
- local: groot
|
||||||
- local: eo1
|
title: NVIDIA GR00T N1.5
|
||||||
title: EO-1
|
- local: xvla
|
||||||
- local: groot
|
title: X-VLA
|
||||||
title: NVIDIA GR00T N1.5
|
- local: multi_task_dit
|
||||||
- local: xvla
|
title: Multitask DiT Policy
|
||||||
title: X-VLA
|
- local: walloss
|
||||||
- local: walloss
|
title: WALL-OSS
|
||||||
title: WALL-OSS
|
title: "Policies"
|
||||||
- local: multi_task_dit
|
|
||||||
title: Multitask DiT
|
|
||||||
title: Policy reference
|
|
||||||
title: Policies
|
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: async
|
|
||||||
title: Async inference
|
|
||||||
- local: rtc
|
|
||||||
title: Real-time chunking (RTC)
|
|
||||||
title: Real-time deployment
|
|
||||||
|
|
||||||
- sections:
|
|
||||||
- local: hilserl
|
|
||||||
title: Train a robot with RL (HIL-SERL)
|
|
||||||
- local: hilserl_sim
|
|
||||||
title: Train RL in simulation
|
|
||||||
- local: sarm
|
- local: sarm
|
||||||
title: SARM reward model
|
title: SARM
|
||||||
title: Reinforcement learning
|
title: "Reward Models"
|
||||||
|
- sections:
|
||||||
|
- local: inference
|
||||||
|
title: Policy Deployment (lerobot-rollout)
|
||||||
|
- local: async
|
||||||
|
title: Use Async Inference
|
||||||
|
- local: rtc
|
||||||
|
title: Real-Time Chunking (RTC)
|
||||||
|
title: "Inference"
|
||||||
- sections:
|
- sections:
|
||||||
- local: envhub
|
- local: envhub
|
||||||
title: Environments from the Hub
|
title: Environments from the Hub
|
||||||
- local: envhub_leisaac
|
- local: envhub_leisaac
|
||||||
title: LeIsaac — control & train in sim
|
title: Control & Train Robots in Sim (LeIsaac)
|
||||||
|
title: "Simulation"
|
||||||
|
- sections:
|
||||||
|
- local: adding_benchmarks
|
||||||
|
title: Adding a New Benchmark
|
||||||
|
- local: libero
|
||||||
|
title: LIBERO
|
||||||
|
- local: libero_plus
|
||||||
|
title: LIBERO-plus
|
||||||
|
- local: metaworld
|
||||||
|
title: Meta-World
|
||||||
|
- local: robotwin
|
||||||
|
title: RoboTwin 2.0
|
||||||
|
- local: robocasa
|
||||||
|
title: RoboCasa365
|
||||||
|
- local: robocerebra
|
||||||
|
title: RoboCerebra
|
||||||
|
- local: robomme
|
||||||
|
title: RoboMME
|
||||||
- local: envhub_isaaclab_arena
|
- local: envhub_isaaclab_arena
|
||||||
title: NVIDIA IsaacLab Arena environments
|
title: NVIDIA IsaacLab Arena Environments
|
||||||
- sections:
|
- local: vlabench
|
||||||
- local: libero
|
title: VLABench
|
||||||
title: LIBERO
|
title: "Benchmarks"
|
||||||
- local: libero_plus
|
|
||||||
title: LIBERO-plus
|
|
||||||
- local: metaworld
|
|
||||||
title: Meta-World
|
|
||||||
- local: robotwin
|
|
||||||
title: RoboTwin 2.0
|
|
||||||
- local: robocasa
|
|
||||||
title: RoboCasa365
|
|
||||||
- local: robocerebra
|
|
||||||
title: RoboCerebra
|
|
||||||
- local: robomme
|
|
||||||
title: RoboMME
|
|
||||||
- local: vlabench
|
|
||||||
title: VLABench
|
|
||||||
title: Benchmark suites
|
|
||||||
title: Simulation & benchmarks
|
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: introduction_processors
|
- local: introduction_processors
|
||||||
title: Introduction to processors
|
title: Introduction to Robot Processors
|
||||||
- local: processors_robots_teleop
|
|
||||||
title: Processors for robots & teleoperators
|
|
||||||
- local: env_processor
|
|
||||||
title: Environment processors
|
|
||||||
- local: action_representations
|
|
||||||
title: Action representations
|
|
||||||
- local: debug_processor_pipeline
|
- local: debug_processor_pipeline
|
||||||
title: Debugging a pipeline
|
title: Debug your processor pipeline
|
||||||
- local: implement_your_own_processor
|
- local: implement_your_own_processor
|
||||||
title: Implementing your own processor
|
title: Implement your own processor
|
||||||
title: Processors
|
- local: processors_robots_teleop
|
||||||
|
title: Processors for Robots and Teleoperators
|
||||||
|
- local: env_processor
|
||||||
|
title: Environment Processors
|
||||||
|
- local: action_representations
|
||||||
|
title: Action Representations
|
||||||
|
title: "Robot Processors"
|
||||||
- sections:
|
- sections:
|
||||||
- sections:
|
- local: so101
|
||||||
- local: so101
|
title: SO-101
|
||||||
title: SO-101
|
- local: so100
|
||||||
- local: so100
|
title: SO-100
|
||||||
title: SO-100
|
- local: koch
|
||||||
- local: koch
|
title: Koch v1.1
|
||||||
title: Koch v1.1
|
- local: lekiwi
|
||||||
- local: omx
|
title: LeKiwi
|
||||||
title: OMX
|
- local: hope_jr
|
||||||
- local: openarm
|
title: Hope Jr
|
||||||
title: OpenArm
|
- local: reachy2
|
||||||
title: Low-cost arms
|
title: Reachy 2
|
||||||
- sections:
|
- local: unitree_g1
|
||||||
- local: lekiwi
|
title: Unitree G1
|
||||||
title: LeKiwi
|
- local: earthrover_mini_plus
|
||||||
- local: earthrover_mini_plus
|
title: Earth Rover Mini
|
||||||
title: Earth Rover Mini
|
- local: omx
|
||||||
title: Mobile platforms
|
title: OMX
|
||||||
- sections:
|
- local: openarm
|
||||||
- local: hope_jr
|
title: OpenArm
|
||||||
title: Hope Jr
|
- local: rebot_b601
|
||||||
- local: reachy2
|
title: reBot B601-DM
|
||||||
title: Reachy 2
|
title: "Robots"
|
||||||
- local: unitree_g1
|
- sections:
|
||||||
title: Unitree G1
|
- local: phone_teleop
|
||||||
title: Bimanual & humanoid
|
title: Phone
|
||||||
- sections:
|
title: "Teleoperators"
|
||||||
- local: rebot_b601
|
|
||||||
title: reBot B601-DM
|
|
||||||
title: Research & industrial
|
|
||||||
title: Supported robots
|
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: cameras
|
- local: cameras
|
||||||
title: Cameras
|
title: Cameras
|
||||||
- local: phone_teleop
|
title: "Sensors"
|
||||||
title: Phone teleoperation
|
|
||||||
- local: feetech
|
|
||||||
title: Feetech firmware update
|
|
||||||
- local: damiao
|
|
||||||
title: Damiao motors & CAN bus
|
|
||||||
title: Sensors, teleop & motors
|
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: integrate_hardware
|
|
||||||
title: Bring your own hardware
|
|
||||||
- local: bring_your_own_policies
|
|
||||||
title: Add a new policy
|
|
||||||
- local: adding_benchmarks
|
|
||||||
title: Add a new benchmark
|
|
||||||
title: Extend LeRobot
|
|
||||||
|
|
||||||
- sections:
|
|
||||||
- local: troubleshooting # NEW (to create) — common errors: USB, calibration drift, CUDA OOM, video decoding…
|
|
||||||
title: Troubleshooting & FAQ
|
|
||||||
- local: glossary # NEW (to create) — episode, action chunk, leader/follower, teleop, processor…
|
|
||||||
title: Glossary
|
|
||||||
- local: notebooks
|
- local: notebooks
|
||||||
title: Example notebooks
|
title: Notebooks
|
||||||
- local: backwardcomp
|
- local: feetech
|
||||||
title: Backward compatibility
|
title: Updating Feetech Firmware
|
||||||
title: Reference
|
- local: damiao
|
||||||
|
title: Damiao Motors and CAN Bus
|
||||||
|
title: "Resources"
|
||||||
- sections:
|
- sections:
|
||||||
- local: contributing
|
- local: contributing
|
||||||
title: Contributing to LeRobot
|
title: Contribute to LeRobot
|
||||||
title: About
|
- local: backwardcomp
|
||||||
|
title: Backward compatibility
|
||||||
|
title: "About"
|
||||||
|
|||||||
@@ -79,13 +79,17 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
|
|||||||
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-rollout \
|
lerobot-record \
|
||||||
--strategy.type=base \
|
--robot.type=so100_follower \
|
||||||
--policy.path=${HF_USER}/act_policy \
|
|
||||||
--robot.type=so101_follower \
|
|
||||||
--robot.port=/dev/ttyACM0 \
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--robot.id=my_robot \
|
||||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
--display_data=true \
|
--display_data=true \
|
||||||
--task="Your task description" \ # can be skipped for ACT
|
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||||
--duration=60
|
--dataset.num_episodes=10 \
|
||||||
|
--dataset.single_task="Your task description" \
|
||||||
|
--dataset.streaming_encoding=true \
|
||||||
|
--dataset.encoder_threads=2 \
|
||||||
|
# --dataset.camera_encoder.vcodec=auto \
|
||||||
|
--policy.path=${HF_USER}/act_policy
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -105,12 +105,10 @@ These results demonstrate GR00T's strong generalization capabilities across dive
|
|||||||
|
|
||||||
### Evaluate in your hardware setup
|
### Evaluate in your hardware setup
|
||||||
|
|
||||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
|
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-rollout\
|
lerobot-record \
|
||||||
--strategy.type=sentry \
|
|
||||||
--strategy.upload_every_n_episodes=5 \
|
|
||||||
--robot.type=bi_so_follower \
|
--robot.type=bi_so_follower \
|
||||||
--robot.left_arm_port=/dev/ttyACM1 \
|
--robot.left_arm_port=/dev/ttyACM1 \
|
||||||
--robot.right_arm_port=/dev/ttyACM0 \
|
--robot.right_arm_port=/dev/ttyACM0 \
|
||||||
@@ -121,12 +119,14 @@ lerobot-rollout\
|
|||||||
}' \
|
}' \
|
||||||
--display_data=true \
|
--display_data=true \
|
||||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||||
|
--dataset.num_episodes=10 \
|
||||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||||
--dataset.streaming_encoding=true \
|
--dataset.streaming_encoding=true \
|
||||||
--dataset.encoder_threads=2 \
|
--dataset.encoder_threads=2 \
|
||||||
# --dataset.camera_encoder.vcodec=auto \
|
# --dataset.camera_encoder.vcodec=auto \
|
||||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||||
--duration=600
|
--dataset.episode_time_s=30 \
|
||||||
|
--dataset.reset_time_s=10
|
||||||
```
|
```
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|||||||
@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
|||||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||||
|
|
||||||
robot_config = SO101FollowerConfig(
|
robot_config = SO101FollowerConfig(
|
||||||
port="/dev/tty.usbmodem5AB90687491",
|
port="/dev/tty.usbmodem58760431541",
|
||||||
id="my_follower_arm",
|
id="my_red_robot_arm",
|
||||||
)
|
)
|
||||||
|
|
||||||
teleop_config = SO101LeaderConfig(
|
teleop_config = SO101LeaderConfig(
|
||||||
port="/dev/tty.usbmodem5AB90689011",
|
port="/dev/tty.usbmodem58760431551",
|
||||||
id="my_leader_arm",
|
id="my_blue_leader_arm",
|
||||||
)
|
)
|
||||||
|
|
||||||
robot = SO101Follower(robot_config)
|
robot = SO101Follower(robot_config)
|
||||||
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
|||||||
<hfoption id="Command">
|
<hfoption id="Command">
|
||||||
```bash
|
```bash
|
||||||
lerobot-teleoperate \
|
lerobot-teleoperate \
|
||||||
--robot.type=so101_follower \
|
--robot.type=koch_follower \
|
||||||
--robot.port=/dev/tty.usbmodem5AB90687491 \
|
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||||
--robot.id=my_follower_arm \
|
--robot.id=my_awesome_follower_arm \
|
||||||
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||||
--teleop.type=so101_leader \
|
--teleop.type=koch_leader \
|
||||||
--teleop.port=/dev/tty.usbmodem5AB90689011 \
|
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||||
--teleop.id=my_leader_arm \
|
--teleop.id=my_awesome_leader_arm \
|
||||||
--display_data=true
|
--display_data=true
|
||||||
```
|
```
|
||||||
</hfoption>
|
</hfoption>
|
||||||
@@ -122,48 +122,34 @@ lerobot-teleoperate \
|
|||||||
|
|
||||||
<!-- prettier-ignore-start -->
|
<!-- prettier-ignore-start -->
|
||||||
```python
|
```python
|
||||||
import time
|
|
||||||
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
|
||||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
|
||||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
|
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
|
||||||
|
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
|
||||||
|
|
||||||
robot_config = SO101FollowerConfig(
|
camera_config = {
|
||||||
port="/dev/tty.usbmodem5AB90687491",
|
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
|
||||||
id="my_follower_arm",
|
}
|
||||||
cameras={
|
|
||||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
robot_config = KochFollowerConfig(
|
||||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
port="/dev/tty.usbmodem585A0076841",
|
||||||
}
|
id="my_red_robot_arm",
|
||||||
|
cameras=camera_config
|
||||||
)
|
)
|
||||||
|
|
||||||
teleop_config = SO101LeaderConfig(
|
teleop_config = KochLeaderConfig(
|
||||||
port="/dev/tty.usbmodem5AB90689011",
|
port="/dev/tty.usbmodem58760431551",
|
||||||
id="my_leader_arm",
|
id="my_blue_leader_arm",
|
||||||
)
|
)
|
||||||
|
|
||||||
init_rerun(session_name="teleoperation")
|
robot = KochFollower(robot_config)
|
||||||
|
teleop_device = KochLeader(teleop_config)
|
||||||
robot = SO101Follower(robot_config)
|
|
||||||
teleop_device = SO101Leader(teleop_config)
|
|
||||||
robot.connect()
|
robot.connect()
|
||||||
teleop_device.connect()
|
teleop_device.connect()
|
||||||
|
|
||||||
TARGET_HZ = 30
|
|
||||||
TIME_PER_FRAME = 1.0 / TARGET_HZ
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
observation = robot.get_observation()
|
observation = robot.get_observation()
|
||||||
action = teleop_device.get_action()
|
action = teleop_device.get_action()
|
||||||
robot.send_action(action)
|
robot.send_action(action)
|
||||||
log_rerun_data(observation=observation, action=action)
|
|
||||||
|
|
||||||
elapsed_time = time.perf_counter() - start_time
|
|
||||||
sleep_time = TIME_PER_FRAME - elapsed_time
|
|
||||||
if sleep_time > 0:
|
|
||||||
time.sleep(sleep_time)
|
|
||||||
```
|
```
|
||||||
<!-- prettier-ignore-end -->
|
<!-- prettier-ignore-end -->
|
||||||
|
|
||||||
@@ -216,11 +202,10 @@ lerobot-record \
|
|||||||
<!-- prettier-ignore-start -->
|
<!-- prettier-ignore-start -->
|
||||||
```python
|
```python
|
||||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets import LeRobotDataset
|
||||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||||
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
|
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||||
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
|
|
||||||
from lerobot.common.control_utils import init_keyboard_listener
|
from lerobot.common.control_utils import init_keyboard_listener
|
||||||
from lerobot.utils.utils import log_say
|
from lerobot.utils.utils import log_say
|
||||||
from lerobot.utils.visualization_utils import init_rerun
|
from lerobot.utils.visualization_utils import init_rerun
|
||||||
@@ -233,56 +218,71 @@ EPISODE_TIME_SEC = 60
|
|||||||
RESET_TIME_SEC = 10
|
RESET_TIME_SEC = 10
|
||||||
TASK_DESCRIPTION = "My task description"
|
TASK_DESCRIPTION = "My task description"
|
||||||
|
|
||||||
def main():
|
# Create robot configuration
|
||||||
# Create robot configuration
|
robot_config = SO100FollowerConfig(
|
||||||
robot_config = SO101FollowerConfig(
|
id="my_awesome_follower_arm",
|
||||||
port="/dev/tty.usbmodem5AB90687491",
|
cameras={
|
||||||
id="my_follower_arm",
|
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
||||||
cameras={
|
},
|
||||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
port="/dev/tty.usbmodem58760434471",
|
||||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
)
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
teleop_config = SO101LeaderConfig(
|
teleop_config = SO100LeaderConfig(
|
||||||
port="/dev/tty.usbmodem5AB90689011",
|
id="my_awesome_leader_arm",
|
||||||
id="my_leader_arm",
|
port="/dev/tty.usbmodem585A0077581",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the robot and teleoperator
|
# Initialize the robot and teleoperator
|
||||||
robot = SO101Follower(robot_config)
|
robot = SO100Follower(robot_config)
|
||||||
teleop = SO101Leader(teleop_config)
|
teleop = SO100Leader(teleop_config)
|
||||||
|
|
||||||
# Configure the dataset features
|
# Configure the dataset features
|
||||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||||
dataset_features = {**action_features, **obs_features}
|
dataset_features = {**action_features, **obs_features}
|
||||||
|
|
||||||
# Create the dataset
|
# Create the dataset
|
||||||
dataset = LeRobotDataset.create(
|
dataset = LeRobotDataset.create(
|
||||||
repo_id="<hf_username>/<dataset_repo_id>",
|
repo_id="<hf_username>/<dataset_repo_id>",
|
||||||
|
fps=FPS,
|
||||||
|
features=dataset_features,
|
||||||
|
robot_type=robot.name,
|
||||||
|
use_videos=True,
|
||||||
|
image_writer_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the keyboard listener and rerun visualization
|
||||||
|
_, events = init_keyboard_listener()
|
||||||
|
init_rerun(session_name="recording")
|
||||||
|
|
||||||
|
# Connect the robot and teleoperator
|
||||||
|
robot.connect()
|
||||||
|
teleop.connect()
|
||||||
|
|
||||||
|
# Create the required processors
|
||||||
|
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||||
|
|
||||||
|
episode_idx = 0
|
||||||
|
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||||
|
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||||
|
|
||||||
|
record_loop(
|
||||||
|
robot=robot,
|
||||||
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
features=dataset_features,
|
teleop_action_processor=teleop_action_processor,
|
||||||
robot_type=robot.name,
|
robot_action_processor=robot_action_processor,
|
||||||
use_videos=True,
|
robot_observation_processor=robot_observation_processor,
|
||||||
image_writer_threads=4,
|
teleop=teleop,
|
||||||
|
dataset=dataset,
|
||||||
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
|
single_task=TASK_DESCRIPTION,
|
||||||
|
display_data=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the keyboard listener and rerun visualization
|
# Reset the environment if not stopping or re-recording
|
||||||
_, events = init_keyboard_listener()
|
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||||
init_rerun(session_name="recording")
|
log_say("Reset the environment")
|
||||||
|
|
||||||
# Connect the robot and teleoperator
|
|
||||||
robot.connect()
|
|
||||||
teleop.connect()
|
|
||||||
|
|
||||||
# Create the required processors
|
|
||||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
|
||||||
|
|
||||||
episode_idx = 0
|
|
||||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
|
||||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
|
||||||
|
|
||||||
record_loop(
|
record_loop(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
@@ -291,50 +291,26 @@ def main():
|
|||||||
robot_action_processor=robot_action_processor,
|
robot_action_processor=robot_action_processor,
|
||||||
robot_observation_processor=robot_observation_processor,
|
robot_observation_processor=robot_observation_processor,
|
||||||
teleop=teleop,
|
teleop=teleop,
|
||||||
dataset=dataset,
|
control_time_s=RESET_TIME_SEC,
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
if events["rerecord_episode"]:
|
||||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
log_say("Re-recording episode")
|
||||||
log_say("Reset the environment")
|
events["rerecord_episode"] = False
|
||||||
record_loop(
|
events["exit_early"] = False
|
||||||
robot=robot,
|
dataset.clear_episode_buffer()
|
||||||
events=events,
|
continue
|
||||||
fps=FPS,
|
|
||||||
teleop_action_processor=teleop_action_processor,
|
|
||||||
robot_action_processor=robot_action_processor,
|
|
||||||
robot_observation_processor=robot_observation_processor,
|
|
||||||
teleop=teleop,
|
|
||||||
control_time_s=RESET_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
|
||||||
display_data=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
dataset.save_episode()
|
||||||
log_say("Re-recording episode")
|
episode_idx += 1
|
||||||
events["rerecord_episode"] = False
|
|
||||||
events["exit_early"] = False
|
|
||||||
dataset.clear_episode_buffer()
|
|
||||||
continue
|
|
||||||
|
|
||||||
dataset.save_episode()
|
# Clean up
|
||||||
episode_idx += 1
|
log_say("Stop recording")
|
||||||
|
robot.disconnect()
|
||||||
# finalize dataset
|
teleop.disconnect()
|
||||||
log_say("Finalizing dataset...")
|
dataset.push_to_hub()
|
||||||
dataset.finalize()
|
|
||||||
# Clean up
|
|
||||||
log_say("Stop recording")
|
|
||||||
robot.disconnect()
|
|
||||||
teleop.disconnect()
|
|
||||||
dataset.push_to_hub()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
```
|
```
|
||||||
<!-- prettier-ignore-end -->
|
<!-- prettier-ignore-end -->
|
||||||
|
|
||||||
@@ -372,7 +348,7 @@ The `record` function provides a suite of tools for capturing and managing data
|
|||||||
##### 2. Checkpointing and Resuming
|
##### 2. Checkpointing and Resuming
|
||||||
|
|
||||||
- Checkpoints are automatically created during recording.
|
- Checkpoints are automatically created during recording.
|
||||||
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
|
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
|
||||||
- To start recording from scratch, **manually delete** the dataset directory.
|
- To start recording from scratch, **manually delete** the dataset directory.
|
||||||
|
|
||||||
##### 3. Recording Parameters
|
##### 3. Recording Parameters
|
||||||
@@ -446,7 +422,7 @@ from lerobot.utils.utils import log_say
|
|||||||
|
|
||||||
episode_idx = 0
|
episode_idx = 0
|
||||||
|
|
||||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
|
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
|
||||||
|
|
||||||
robot = SO100Follower(robot_config)
|
robot = SO100Follower(robot_config)
|
||||||
robot.connect()
|
robot.connect()
|
||||||
@@ -514,83 +490,6 @@ Additionally you can provide extra `tags` or specify a `license` for your model
|
|||||||
|
|
||||||
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||||
|
|
||||||
#### Train using Hugging Face Jobs
|
|
||||||
|
|
||||||
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
|
|
||||||
|
|
||||||
To run the training use this command:
|
|
||||||
|
|
||||||
<hfoptions id="train_with_hf_jobs">
|
|
||||||
<hfoption id="Command">
|
|
||||||
```bash
|
|
||||||
hf jobs run \
|
|
||||||
--flavor a10g-small \
|
|
||||||
--timeout 4h \
|
|
||||||
--secrets HF_TOKEN \
|
|
||||||
huggingface/lerobot-gpu:latest \
|
|
||||||
-- \
|
|
||||||
python -m lerobot.scripts.lerobot_train \
|
|
||||||
--dataset.repo_id=username/dataset \
|
|
||||||
--policy.type=act \
|
|
||||||
--steps=5000 \
|
|
||||||
--batch_size=16 \
|
|
||||||
--policy.device=cuda \
|
|
||||||
--policy.repo_id=username/your_policy \
|
|
||||||
--log_freq=100
|
|
||||||
```
|
|
||||||
</hfoption>
|
|
||||||
<hfoption id="API example">
|
|
||||||
|
|
||||||
<!-- prettier-ignore-start -->
|
|
||||||
```python
|
|
||||||
from huggingface_hub import run_job, get_token
|
|
||||||
|
|
||||||
run_name = "act_so101_hf_jobs"
|
|
||||||
dataset_id = "username/dataset"
|
|
||||||
user_hub_id = "username"
|
|
||||||
|
|
||||||
command_args = [
|
|
||||||
"python", "-m", "lerobot.scripts.lerobot_train",
|
|
||||||
"--dataset.repo_id", dataset_id,
|
|
||||||
"--policy.type", "act",
|
|
||||||
"--steps", "5000",
|
|
||||||
"--batch_size", "16",
|
|
||||||
"--num_workers", "4",
|
|
||||||
"--policy.device", "cuda",
|
|
||||||
"--log_freq", "100",
|
|
||||||
"--save_freq", "1000",
|
|
||||||
"--save_checkpoint", "true",
|
|
||||||
"--wandb.enable", "false",
|
|
||||||
"--policy.repo_id", f"{user_hub_id}/{run_name}"
|
|
||||||
]
|
|
||||||
|
|
||||||
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
|
|
||||||
|
|
||||||
job_info = run_job(
|
|
||||||
image="huggingface/lerobot-gpu:latest",
|
|
||||||
command=command_args,
|
|
||||||
flavor="a10g-small",
|
|
||||||
timeout="4h",
|
|
||||||
secrets={"HF_TOKEN": get_token()}
|
|
||||||
)
|
|
||||||
|
|
||||||
print("\n🚀 Job successfully launched!")
|
|
||||||
print(f"🔹 Job ID: {job_info.id}")
|
|
||||||
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
|
|
||||||
```
|
|
||||||
<!-- prettier-ignore-end -->
|
|
||||||
|
|
||||||
</hfoption>
|
|
||||||
</hfoptions>
|
|
||||||
|
|
||||||
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
|
|
||||||
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
|
|
||||||
For longer training sessions increase the timeout.
|
|
||||||
|
|
||||||
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
|
|
||||||
|
|
||||||
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
|
|
||||||
|
|
||||||
#### Upload policy checkpoints
|
#### Upload policy checkpoints
|
||||||
|
|
||||||
Once training is done, upload the latest checkpoint with:
|
Once training is done, upload the latest checkpoint with:
|
||||||
|
|||||||
@@ -1,219 +0,0 @@
|
|||||||
# Quickstart
|
|
||||||
|
|
||||||
This is the **shortest path** from an unboxed SO-101 to a policy that drives your own robot. Every step is copy-paste; replace the **`<placeholders>`** with the values for your setup.
|
|
||||||
|
|
||||||
By the end you will have:
|
|
||||||
|
|
||||||
- A calibrated SO-101 leader + follower pair.
|
|
||||||
- A dataset of 30 episodes pushed to the Hugging Face Hub.
|
|
||||||
- A trained ACT policy (~20k steps) running on your robot via `lerobot-rollout`.
|
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> **How long will this take?**
|
|
||||||
> Recording 30 episodes is roughly 30–60 minutes of teleoperation. Training ACT for 20k steps takes ~1.5h on an A100, a few hours on a laptop RTX 3060, longer on Apple Silicon (`mps`). The commands themselves are quick — most of the wall-clock is data collection and training.
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> If you only want to **understand the codebase** or **train on an existing dataset without hardware**, this page isn't for you. Read [Core concepts](./core_concepts) first, then jump to [Imitation learning end-to-end](./il_robots).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Before you start
|
|
||||||
|
|
||||||
You need:
|
|
||||||
|
|
||||||
- An **assembled SO-101 leader + follower pair**. If your robot is not assembled yet, follow the [SO-101 assembly guide](./so101) and come back here.
|
|
||||||
- **One or two cameras** (USB webcam works fine).
|
|
||||||
- A **CUDA GPU with ≥ 6 GB VRAM** (ACT is light — a laptop RTX 3060 works). Apple Silicon (`mps`) and CPU are supported but slower. See the [compute hardware guide](./hardware_guide) for sizing.
|
|
||||||
- A **Hugging Face account** — datasets and the trained policy will be pushed to your Hub.
|
|
||||||
|
|
||||||
If any of the above is missing, fix it first; the rest of the page assumes it.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 1 — Install LeRobot
|
|
||||||
|
|
||||||
Follow the full [Installation Guide](./installation) for environment setup, then add the SO-101 motor stack and log in to the Hub:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install 'lerobot[feetech]'
|
|
||||||
git lfs install && git lfs pull
|
|
||||||
hf auth login # paste a token from https://huggingface.co/settings/tokens
|
|
||||||
```
|
|
||||||
|
|
||||||
Sanity check — the CLI entry points should be available:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-find-port --help
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 2 — Identify USB ports and motor IDs
|
|
||||||
|
|
||||||
Plug **only the follower arm** in (USB + power) and run:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-find-port
|
|
||||||
```
|
|
||||||
|
|
||||||
When prompted, unplug it and press Enter. Note the printed port — that's your `<FOLLOWER_PORT>`. Repeat with only the **leader arm** plugged in to get `<LEADER_PORT>`.
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> On Linux, USB ports look like `/dev/ttyACM0`; on macOS like `/dev/tty.usbmodem...`. On Linux you may need `sudo chmod 666 /dev/ttyACM0` to grant access.
|
|
||||||
|
|
||||||
If your motors are brand-new (or repurposed), set their IDs and baudrate **once per arm**:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-setup-motors --robot.type=so101_follower --robot.port=<FOLLOWER_PORT>
|
|
||||||
lerobot-setup-motors --teleop.type=so101_leader --teleop.port=<LEADER_PORT>
|
|
||||||
```
|
|
||||||
|
|
||||||
The script walks you through connecting motors one at a time. Full details: [SO-101 → Configure the motors](./so101#configure-the-motors).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 3 — Calibrate
|
|
||||||
|
|
||||||
Center every joint roughly in the middle of its range, then run:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-calibrate \
|
|
||||||
--robot.type=so101_follower \
|
|
||||||
--robot.port=<FOLLOWER_PORT> \
|
|
||||||
--robot.id=my_follower
|
|
||||||
|
|
||||||
lerobot-calibrate \
|
|
||||||
--teleop.type=so101_leader \
|
|
||||||
--teleop.port=<LEADER_PORT> \
|
|
||||||
--teleop.id=my_leader
|
|
||||||
```
|
|
||||||
|
|
||||||
After pressing Enter, sweep each joint through its full range of motion, then press Enter again to finish.
|
|
||||||
|
|
||||||
> [!WARNING]
|
|
||||||
> The `--robot.id` / `--teleop.id` values (`my_follower`, `my_leader`) become the **calibration keys**. Reuse the same IDs in every later command — that's how LeRobot finds the calibration on disk.
|
|
||||||
|
|
||||||
Watch the [calibration video](./so101#calibrate) if anything is unclear.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 4 — Teleoperate (sanity check, no recording)
|
|
||||||
|
|
||||||
Before recording anything, confirm the leader drives the follower correctly:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-teleoperate \
|
|
||||||
--robot.type=so101_follower \
|
|
||||||
--robot.port=<FOLLOWER_PORT> \
|
|
||||||
--robot.id=my_follower \
|
|
||||||
--robot.cameras="{ top: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
|
|
||||||
--teleop.type=so101_leader \
|
|
||||||
--teleop.port=<LEADER_PORT> \
|
|
||||||
--teleop.id=my_leader \
|
|
||||||
--display_data=true
|
|
||||||
```
|
|
||||||
|
|
||||||
A Rerun window should open showing the camera feed and joint angles. Move the leader — the follower should mirror it in real time. If it doesn't, see [Troubleshooting & FAQ](./troubleshooting).
|
|
||||||
|
|
||||||
Don't know which camera index is which? Run `lerobot-find-cameras` — it saves a frame from each detected camera so you can pick the right one.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 5 — Record a dataset (30 episodes)
|
|
||||||
|
|
||||||
Now record demonstrations. Pick a short, repeatable task (e.g. *"put the red brick in the bowl"*). The dataset is pushed to the Hub under your username:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export HF_USER=<your-hf-username>
|
|
||||||
|
|
||||||
lerobot-record \
|
|
||||||
--robot.type=so101_follower \
|
|
||||||
--robot.port=<FOLLOWER_PORT> \
|
|
||||||
--robot.id=my_follower \
|
|
||||||
--robot.cameras="{ top: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30} }" \
|
|
||||||
--teleop.type=so101_leader \
|
|
||||||
--teleop.port=<LEADER_PORT> \
|
|
||||||
--teleop.id=my_leader \
|
|
||||||
--dataset.repo_id=${HF_USER}/so101_quickstart \
|
|
||||||
--dataset.num_episodes=30 \
|
|
||||||
--dataset.single_task="Put the red brick in the bowl" \
|
|
||||||
--dataset.streaming_encoding=true \
|
|
||||||
--display_data=true
|
|
||||||
```
|
|
||||||
|
|
||||||
**Keyboard controls during recording:**
|
|
||||||
|
|
||||||
- **`→` (Right Arrow)** — save the current episode and move to the next.
|
|
||||||
- **`←` (Left Arrow)** — discard the current episode and retry.
|
|
||||||
- **`Esc`** — stop, encode videos, and upload to the Hub.
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> **Quality beats quantity.** 30 clean, varied episodes (different brick positions, lighting, camera shake) train a much better policy than 100 identical ones. Move the object around. Vary your speed slightly.
|
|
||||||
|
|
||||||
When you're done, your dataset lives at `https://huggingface.co/datasets/${HF_USER}/so101_quickstart`. You can preview it in the browser. For deeper recording options (resume, multiple tasks, custom processors), see [Imitation learning end-to-end → Record](./il_robots#record-a-dataset).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 6 — Train ACT
|
|
||||||
|
|
||||||
ACT (Action Chunking Transformer) is the right default for a first run — small, fast, and works well on 30 episodes.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-train \
|
|
||||||
--dataset.repo_id=${HF_USER}/so101_quickstart \
|
|
||||||
--policy.type=act \
|
|
||||||
--output_dir=outputs/train/act_so101_quickstart \
|
|
||||||
--job_name=act_so101_quickstart \
|
|
||||||
--policy.device=cuda \
|
|
||||||
--policy.repo_id=${HF_USER}/act_so101_quickstart \
|
|
||||||
--steps=20000 \
|
|
||||||
--wandb.enable=true
|
|
||||||
```
|
|
||||||
|
|
||||||
A few notes:
|
|
||||||
|
|
||||||
- Replace `--policy.device=cuda` with `mps` on Apple Silicon, or `cpu` if you have no GPU (very slow — not recommended for a real run).
|
|
||||||
- `--wandb.enable=true` is optional. If you use it, run `wandb login` first. Otherwise drop the flag.
|
|
||||||
- Checkpoints land in `outputs/train/act_so101_quickstart/checkpoints/`. The final model is also pushed to the Hub at the `--policy.repo_id` you specified.
|
|
||||||
- To resume from an interruption: `lerobot-train --config_path=outputs/train/act_so101_quickstart/checkpoints/last/pretrained_model/train_config.json --resume=true`.
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> **No GPU locally?** Train on Google Colab using the [ACT notebook](./notebooks#training-act), or rent a GPU via [Hugging Face Jobs](./il_robots#train-using-hugging-face-jobs) — pay-as-you-go, no setup.
|
|
||||||
|
|
||||||
For why ACT is the default and when to switch to SmolVLA, Pi0, or another policy, see [Choosing a policy](./policies_overview).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 7 — Run your policy on the robot
|
|
||||||
|
|
||||||
Deploy with `lerobot-rollout`. **Use the same camera layout you used while recording** — keys and resolutions must match.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-rollout \
|
|
||||||
--strategy.type=base \
|
|
||||||
--policy.path=${HF_USER}/act_so101_quickstart \
|
|
||||||
--robot.type=so101_follower \
|
|
||||||
--robot.port=<FOLLOWER_PORT> \
|
|
||||||
--robot.id=my_follower \
|
|
||||||
--robot.cameras="{ top: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30} }" \
|
|
||||||
--task="Put the red brick in the bowl" \
|
|
||||||
--duration=60
|
|
||||||
```
|
|
||||||
|
|
||||||
`--duration` is in seconds — leave it off to run until you stop the script. You should see the follower arm move on its own, attempting the task.
|
|
||||||
|
|
||||||
If observations from the robot use different keys than the policy expects, you'll need a [rename map](./rename_map). If latency matters, look at [async inference](./async) and [real-time chunking](./rtc).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## You're done 🎉
|
|
||||||
|
|
||||||
You now have a working IL pipeline end-to-end. From here, the natural next steps are:
|
|
||||||
|
|
||||||
- **Improve the policy** — record more diverse episodes, train longer, or try a stronger model. See [Choosing a policy](./policies_overview).
|
|
||||||
- **Go deeper on imitation learning** — [Imitation learning end-to-end](./il_robots) covers multi-camera setups, multi-task datasets, episode replay, evaluation, and Hugging Face Jobs.
|
|
||||||
- **Try RL with a human in the loop** — [HIL-SERL](./hilserl) trains a policy that improves while you correct it.
|
|
||||||
- **Use a different robot** — see [Supported robots](./so101) for low-cost arms, mobile platforms, bimanual, and humanoid.
|
|
||||||
- **Build something new** — [Bring your own hardware](./integrate_hardware) and [Add a new policy](./bring_your_own_policies).
|
|
||||||
|
|
||||||
Stuck on something? Check [Troubleshooting & FAQ](./troubleshooting), or ask on [Discord](https://discord.gg/s3KuuzsPFb).
|
|
||||||
@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
|
|||||||
Once you are logged in, you can run inference in your setup by doing:
|
Once you are logged in, you can run inference in your setup by doing:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-rollout \
|
lerobot-record \
|
||||||
--strategy.type=base \
|
|
||||||
--robot.type=so101_follower \
|
--robot.type=so101_follower \
|
||||||
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
||||||
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
||||||
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
|
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
|
||||||
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||||
# <- RTC optional, use when running on low power hardware \
|
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
|
||||||
# --inference.type=rtc \
|
--dataset.episode_time_s=50 \
|
||||||
# --inference.rtc.execution_horizon=10 \
|
--dataset.num_episodes=10 \
|
||||||
# --inference.rtc.max_guidance_weight=10.0 \
|
--dataset.streaming_encoding=true \
|
||||||
|
--dataset.encoder_threads=2 \
|
||||||
|
# --dataset.camera_encoder.vcodec=auto \
|
||||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||||
# --teleop.type=so100_leader \
|
# --teleop.type=so100_leader \
|
||||||
# --teleop.port=/dev/ttyACM0 \
|
# --teleop.port=/dev/ttyACM0 \
|
||||||
# --teleop.id=my_red_leader_arm \
|
# --teleop.id=my_red_leader_arm \
|
||||||
# --display_data=true #optional use if you want to see the camera stream \
|
|
||||||
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
|
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -15,12 +15,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
|
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
|
||||||
|
|
||||||
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
||||||
of the source video, draws a progress line on each frame, and writes the result.
|
of the source video, draws a progress line on each frame, and writes the result.
|
||||||
The progress data is read from a parquet file that lives alongside the dataset
|
|
||||||
(configurable via ``--progress-file``).
|
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python examples/dataset/create_progress_videos.py \
|
python examples/dataset/create_progress_videos.py \
|
||||||
@@ -58,26 +56,22 @@ SCORE_FONT_SCALE = 0.8
|
|||||||
TASK_FONT_SCALE = 0.55
|
TASK_FONT_SCALE = 0.55
|
||||||
|
|
||||||
|
|
||||||
def download_episode_metadata(
|
def download_episode_metadata(repo_id: str, episode: int) -> Path:
|
||||||
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
|
"""Download only the metadata and sarm_progress files for a dataset.
|
||||||
) -> Path:
|
|
||||||
"""Download only the metadata and per-frame progress file for a dataset.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
repo_id: HuggingFace dataset repository ID.
|
repo_id: HuggingFace dataset repository ID.
|
||||||
episode: Episode index (used for logging only; all meta is fetched).
|
episode: Episode index (used for logging only; all meta is fetched).
|
||||||
progress_file: Filename of the per-frame progress parquet inside the
|
|
||||||
dataset repo.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Local cache path for the downloaded snapshot.
|
Local cache path for the downloaded snapshot.
|
||||||
"""
|
"""
|
||||||
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
|
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
|
||||||
local_path = Path(
|
local_path = Path(
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
allow_patterns=["meta/**", progress_file],
|
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||||
ignore_patterns=["*.mp4"],
|
ignore_patterns=["*.mp4"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -221,28 +215,25 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
|
|||||||
return video_path
|
return video_path
|
||||||
|
|
||||||
|
|
||||||
def load_progress_data(
|
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
|
||||||
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
|
"""Load sarm_progress values for an episode.
|
||||||
) -> np.ndarray | None:
|
|
||||||
"""Load per-frame progress values for an episode.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
local_path: Dataset cache root.
|
local_path: Dataset cache root.
|
||||||
episode: Episode index.
|
episode: Episode index.
|
||||||
progress_file: Filename of the per-frame progress parquet.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
||||||
"""
|
"""
|
||||||
parquet_path = local_path / progress_file
|
parquet_path = local_path / "sarm_progress.parquet"
|
||||||
if not parquet_path.exists():
|
if not parquet_path.exists():
|
||||||
logging.warning("%s not found", progress_file)
|
logging.warning("sarm_progress.parquet not found")
|
||||||
return None
|
return None
|
||||||
df = pd.read_parquet(parquet_path)
|
df = pd.read_parquet(parquet_path)
|
||||||
logging.info(" %s columns: %s", progress_file, list(df.columns))
|
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
|
||||||
episode_df = df[df["episode_index"] == episode].copy()
|
episode_df = df[df["episode_index"] == episode].copy()
|
||||||
if episode_df.empty:
|
if episode_df.empty:
|
||||||
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
|
logging.warning("No sarm_progress rows for episode %d", episode)
|
||||||
return None
|
return None
|
||||||
episode_df = episode_df.sort_values("frame_index")
|
episode_df = episode_df.sort_values("frame_index")
|
||||||
|
|
||||||
@@ -585,7 +576,6 @@ def process_dataset(
|
|||||||
camera_key: str | None,
|
camera_key: str | None,
|
||||||
output_dir: Path,
|
output_dir: Path,
|
||||||
create_gif: bool = False,
|
create_gif: bool = False,
|
||||||
progress_file: str = "sarm_progress.parquet",
|
|
||||||
) -> Path | None:
|
) -> Path | None:
|
||||||
"""Full pipeline: download, extract metadata, composite progress, write output.
|
"""Full pipeline: download, extract metadata, composite progress, write output.
|
||||||
|
|
||||||
@@ -595,8 +585,6 @@ def process_dataset(
|
|||||||
camera_key: Camera key to use, or None for auto-selection.
|
camera_key: Camera key to use, or None for auto-selection.
|
||||||
output_dir: Directory to write output files.
|
output_dir: Directory to write output files.
|
||||||
create_gif: If True, also generate a GIF from the MP4.
|
create_gif: If True, also generate a GIF from the MP4.
|
||||||
progress_file: Filename of the per-frame progress parquet inside the
|
|
||||||
dataset repo.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Path to the final output file, or None on failure.
|
Path to the final output file, or None on failure.
|
||||||
@@ -604,7 +592,7 @@ def process_dataset(
|
|||||||
safe_name = repo_id.replace("/", "_")
|
safe_name = repo_id.replace("/", "_")
|
||||||
logging.info("Processing: %s | episode %d", repo_id, episode)
|
logging.info("Processing: %s | episode %d", repo_id, episode)
|
||||||
|
|
||||||
local_path = download_episode_metadata(repo_id, episode, progress_file)
|
local_path = download_episode_metadata(repo_id, episode)
|
||||||
logging.info(" Local cache: %s", local_path)
|
logging.info(" Local cache: %s", local_path)
|
||||||
|
|
||||||
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
||||||
@@ -612,9 +600,9 @@ def process_dataset(
|
|||||||
|
|
||||||
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
||||||
|
|
||||||
progress_data = load_progress_data(local_path, episode, progress_file)
|
progress_data = load_progress_data(local_path, episode)
|
||||||
if progress_data is None:
|
if progress_data is None:
|
||||||
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
|
logging.error("Could not load sarm_progress data. Skipping overlay.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
logging.info(" Progress frames: %d", len(progress_data))
|
logging.info(" Progress frames: %d", len(progress_data))
|
||||||
@@ -639,7 +627,7 @@ def process_dataset(
|
|||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
|
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--repo-id",
|
"--repo-id",
|
||||||
@@ -670,15 +658,6 @@ def main() -> None:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Also generate a GIF from the MP4 output.",
|
help="Also generate a GIF from the MP4 output.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--progress-file",
|
|
||||||
type=str,
|
|
||||||
default="sarm_progress.parquet",
|
|
||||||
help=(
|
|
||||||
"Filename of the per-frame progress parquet inside the dataset repo "
|
|
||||||
"(default: 'sarm_progress.parquet')."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
@@ -691,7 +670,6 @@ def main() -> None:
|
|||||||
camera_key=args.camera_key,
|
camera_key=args.camera_key,
|
||||||
output_dir=args.output_dir,
|
output_dir=args.output_dir,
|
||||||
create_gif=args.gif,
|
create_gif=args.gif,
|
||||||
progress_file=args.progress_file,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
|
|||||||
@@ -138,9 +138,7 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
|||||||
# Common
|
# Common
|
||||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||||
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
|
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
|
||||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
|
||||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||||
|
|||||||
@@ -18,25 +18,12 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
from lerobot.utils.import_utils import _placo_available, require_package
|
||||||
|
|
||||||
_placo_runtime_error: ImportError | None = None
|
if TYPE_CHECKING or _placo_available:
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
import placo # type: ignore[import-not-found]
|
import placo # type: ignore[import-not-found]
|
||||||
else:
|
else:
|
||||||
try:
|
placo = None
|
||||||
import placo # type: ignore[import-not-found]
|
|
||||||
except ImportError as _placo_import_err:
|
|
||||||
placo = None
|
|
||||||
_placo_runtime_error = _placo_import_err
|
|
||||||
|
|
||||||
|
|
||||||
def _raise_if_placo_unusable() -> None:
|
|
||||||
if placo is None and _placo_runtime_error is not None:
|
|
||||||
raise ImportError(
|
|
||||||
f"placo is installed but failed to import: {_placo_runtime_error!s}"
|
|
||||||
) from _placo_runtime_error
|
|
||||||
|
|
||||||
|
|
||||||
class RobotKinematics:
|
class RobotKinematics:
|
||||||
@@ -57,7 +44,6 @@ class RobotKinematics:
|
|||||||
joint_names (list[str] | None): List of joint names to use for the kinematics solver
|
joint_names (list[str] | None): List of joint names to use for the kinematics solver
|
||||||
"""
|
"""
|
||||||
require_package("placo", extra="placo-dep")
|
require_package("placo", extra="placo-dep")
|
||||||
_raise_if_placo_unusable()
|
|
||||||
|
|
||||||
self.robot = placo.RobotWrapper(urdf_path)
|
self.robot = placo.RobotWrapper(urdf_path)
|
||||||
self.solver = placo.KinematicsSolver(self.robot)
|
self.solver = placo.KinematicsSolver(self.robot)
|
||||||
|
|||||||
@@ -43,7 +43,6 @@ from .tables import (
|
|||||||
CAN_CMD_SET_ZERO,
|
CAN_CMD_SET_ZERO,
|
||||||
DEFAULT_BAUDRATE,
|
DEFAULT_BAUDRATE,
|
||||||
DEFAULT_TIMEOUT_MS,
|
DEFAULT_TIMEOUT_MS,
|
||||||
HANDSHAKE_TIMEOUT_S,
|
|
||||||
MODEL_RESOLUTION,
|
MODEL_RESOLUTION,
|
||||||
MOTOR_LIMIT_PARAMS,
|
MOTOR_LIMIT_PARAMS,
|
||||||
NORMALIZED_DATA,
|
NORMALIZED_DATA,
|
||||||
@@ -216,16 +215,14 @@ class RobstrideMotorsBus(MotorsBusBase):
|
|||||||
self._is_connected = False
|
self._is_connected = False
|
||||||
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
|
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
|
||||||
|
|
||||||
def _query_status_via_clear_fault(
|
def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
|
||||||
self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
|
|
||||||
) -> tuple[bool, can.Message | None]:
|
|
||||||
motor_name = self._get_motor_name(motor)
|
motor_name = self._get_motor_name(motor)
|
||||||
motor_id = self._get_motor_id(motor_name)
|
motor_id = self._get_motor_id(motor_name)
|
||||||
recv_id = self._get_motor_recv_id(motor_name)
|
recv_id = self._get_motor_recv_id(motor_name)
|
||||||
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
|
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
|
||||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||||
self._bus().send(msg)
|
self._bus().send(msg)
|
||||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)
|
return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
|
||||||
|
|
||||||
def _recv_status_via_clear_fault(
|
def _recv_status_via_clear_fault(
|
||||||
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
|
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
|
||||||
@@ -283,7 +280,7 @@ class RobstrideMotorsBus(MotorsBusBase):
|
|||||||
faulted_motors = []
|
faulted_motors = []
|
||||||
|
|
||||||
for motor_name in self.motors:
|
for motor_name in self.motors:
|
||||||
has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
|
has_fault, msg = self._query_status_via_clear_fault(motor_name)
|
||||||
if msg is None:
|
if msg is None:
|
||||||
missing_motors.append(motor_name)
|
missing_motors.append(motor_name)
|
||||||
elif has_fault:
|
elif has_fault:
|
||||||
@@ -508,87 +505,6 @@ class RobstrideMotorsBus(MotorsBusBase):
|
|||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
def _recv_all_messages_until_quiet(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
timeout: float = RUNNING_TIMEOUT,
|
|
||||||
max_messages: int = 4096,
|
|
||||||
) -> list[can.Message]:
|
|
||||||
"""
|
|
||||||
Receive frames until the bus goes quiet.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timeout: Poll timeout used for each recv() call. Collection stops
|
|
||||||
when one recv() times out (quiet gap).
|
|
||||||
max_messages: Safety cap to prevent unbounded loops.
|
|
||||||
"""
|
|
||||||
out: list[can.Message] = []
|
|
||||||
max_messages = max(1, max_messages)
|
|
||||||
timeout = max(0.0, timeout)
|
|
||||||
|
|
||||||
try:
|
|
||||||
while len(out) < max_messages:
|
|
||||||
msg = self._bus().recv(timeout=timeout)
|
|
||||||
if msg is None:
|
|
||||||
break
|
|
||||||
out.append(msg)
|
|
||||||
except (can.CanError, OSError) as e:
|
|
||||||
logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
|
|
||||||
"""
|
|
||||||
Decode all received feedback frames and update cached motor states.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Set of payload recv_ids that were successfully mapped to motors.
|
|
||||||
"""
|
|
||||||
processed_recv_ids: set[int] = set()
|
|
||||||
for msg in messages:
|
|
||||||
if len(msg.data) < 1:
|
|
||||||
logger.debug(
|
|
||||||
f"Dropping short CAN frame on {self.port} "
|
|
||||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
recv_id = int(msg.data[0])
|
|
||||||
motor_name = self._recv_id_to_motor.get(recv_id)
|
|
||||||
if motor_name is None:
|
|
||||||
logger.debug(
|
|
||||||
f"Unmapped CAN frame on {self.port} "
|
|
||||||
f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
self._process_response(motor_name, msg)
|
|
||||||
processed_recv_ids.add(recv_id)
|
|
||||||
|
|
||||||
return processed_recv_ids
|
|
||||||
|
|
||||||
def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
|
|
||||||
"""
|
|
||||||
Drain pending RX frames from the CAN interface.
|
|
||||||
|
|
||||||
This is used by higher-level controllers to drop stale feedback before issuing
|
|
||||||
a fresh read cycle, so subsequent state reads are based on most recent replies.
|
|
||||||
It should also be called once when a controller instance is created/connected,
|
|
||||||
to clear residual frames left on the interface from previous sessions.
|
|
||||||
"""
|
|
||||||
drained = 0
|
|
||||||
poll_timeout_s = max(0.0, poll_timeout_s)
|
|
||||||
max_messages = max(1, max_messages)
|
|
||||||
try:
|
|
||||||
while drained < max_messages:
|
|
||||||
msg = self._bus().recv(timeout=poll_timeout_s)
|
|
||||||
if msg is None:
|
|
||||||
break
|
|
||||||
drained += 1
|
|
||||||
except (can.CanError, OSError) as e:
|
|
||||||
logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
|
|
||||||
return drained
|
|
||||||
|
|
||||||
def _speed_control(
|
def _speed_control(
|
||||||
self,
|
self,
|
||||||
motor: NameOrID,
|
motor: NameOrID,
|
||||||
@@ -728,14 +644,11 @@ class RobstrideMotorsBus(MotorsBusBase):
|
|||||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||||
self._bus().send(msg)
|
self._bus().send(msg)
|
||||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||||
# Read every feedback frame until RX goes quiet, then decode all of them.
|
|
||||||
# This avoids dropping useful frames when responses from different motors interleave.
|
|
||||||
messages = self._recv_all_messages_until_quiet()
|
|
||||||
processed_recv_ids = self._process_feedback_messages(messages)
|
|
||||||
|
|
||||||
|
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
|
||||||
for recv_id, motor_name in recv_id_to_motor.items():
|
for recv_id, motor_name in recv_id_to_motor.items():
|
||||||
if recv_id not in processed_recv_ids:
|
if msg := responses.get(recv_id):
|
||||||
logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")
|
self._process_response(motor_name, msg)
|
||||||
|
|
||||||
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||||
"""Convert float to unsigned integer for CAN transmission."""
|
"""Convert float to unsigned integer for CAN transmission."""
|
||||||
@@ -798,10 +711,7 @@ class RobstrideMotorsBus(MotorsBusBase):
|
|||||||
try:
|
try:
|
||||||
self._decode_motor_state(msg.data)
|
self._decode_motor_state(msg.data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(f"Failed to decode response from {motor}: {e}")
|
||||||
f"Failed to decode response from {motor} "
|
|
||||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_cached_value(self, motor: str, data_name: str) -> Value:
|
def _get_cached_value(self, motor: str, data_name: str) -> Value:
|
||||||
"""Retrieve a specific value from the state cache."""
|
"""Retrieve a specific value from the state cache."""
|
||||||
@@ -938,12 +848,20 @@ class RobstrideMotorsBus(MotorsBusBase):
|
|||||||
self._bus().send(msg)
|
self._bus().send(msg)
|
||||||
updated_motors.append(motor)
|
updated_motors.append(motor)
|
||||||
|
|
||||||
messages = self._recv_all_messages_until_quiet()
|
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
|
||||||
processed_recv_ids = self._process_feedback_messages(messages)
|
responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
|
||||||
|
|
||||||
|
for response in responses.values():
|
||||||
|
payload_motor_name = self._recv_id_to_motor.get(response.data[0])
|
||||||
|
if payload_motor_name is not None:
|
||||||
|
self._process_response(payload_motor_name, response)
|
||||||
|
else:
|
||||||
|
# Fallback: still attempt to decode based on payload byte0 mapping.
|
||||||
|
self._decode_motor_state(response.data)
|
||||||
|
|
||||||
for motor in updated_motors:
|
for motor in updated_motors:
|
||||||
recv_id = self._get_motor_recv_id(motor)
|
recv_id = self._get_motor_recv_id(motor)
|
||||||
if recv_id not in processed_recv_ids:
|
if recv_id not in responses:
|
||||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||||
|
|
||||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||||
|
|||||||
@@ -114,8 +114,7 @@ CAN_CMD_SAVE_PARAM = 0xAA
|
|||||||
CAN_PARAM_ID = 0x7FF
|
CAN_PARAM_ID = 0x7FF
|
||||||
|
|
||||||
|
|
||||||
RUNNING_TIMEOUT = 0.003
|
RUNNING_TIMEOUT = 0.001
|
||||||
HANDSHAKE_TIMEOUT_S = 0.05
|
|
||||||
PARAM_TIMEOUT = 0.01
|
PARAM_TIMEOUT = 0.01
|
||||||
|
|
||||||
STATE_CACHE_TTL_S = 0.02
|
STATE_CACHE_TTL_S = 0.02
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -26,14 +26,9 @@ from lerobot.utils.import_utils import _transformers_available
|
|||||||
|
|
||||||
# Conditional import for type checking and lazy loading
|
# Conditional import for type checking and lazy loading
|
||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
from huggingface_hub.dataclasses import strict
|
|
||||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||||
from transformers.feature_extraction_utils import BatchFeature
|
from transformers.feature_extraction_utils import BatchFeature
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def strict(cls):
|
|
||||||
return cls
|
|
||||||
|
|
||||||
AutoConfig = None
|
AutoConfig = None
|
||||||
AutoModel = None
|
AutoModel = None
|
||||||
PretrainedConfig = object
|
PretrainedConfig = object
|
||||||
@@ -178,20 +173,19 @@ N_COLOR_CHANNELS = 3
|
|||||||
|
|
||||||
|
|
||||||
# config
|
# config
|
||||||
@strict
|
|
||||||
class GR00TN15Config(PretrainedConfig):
|
class GR00TN15Config(PretrainedConfig):
|
||||||
model_type = "gr00t_n1_5"
|
model_type = "gr00t_n1_5"
|
||||||
|
|
||||||
backbone_cfg: dict[str, Any] | None = None
|
backbone_cfg: dict
|
||||||
action_head_cfg: dict[str, Any] | None = None
|
action_head_cfg: dict
|
||||||
action_horizon: int = 0
|
action_horizon: int
|
||||||
action_dim: int = 0
|
action_dim: int
|
||||||
compute_dtype: str = "float32"
|
compute_dtype: str = "float32"
|
||||||
|
|
||||||
def __post_init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
|
super().__init__(**kwargs)
|
||||||
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
|
for key, value in kwargs.items():
|
||||||
super().__post_init__(**kwargs)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
|
||||||
# real model
|
# real model
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
@@ -29,7 +30,6 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
|||||||
|
|
||||||
# Conditional import for type checking and lazy loading
|
# Conditional import for type checking and lazy loading
|
||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
from transformers.cache_utils import DynamicCache
|
|
||||||
from transformers.models.auto import CONFIG_MAPPING
|
from transformers.models.auto import CONFIG_MAPPING
|
||||||
from transformers.models.gemma import modeling_gemma
|
from transformers.models.gemma import modeling_gemma
|
||||||
|
|
||||||
@@ -41,7 +41,6 @@ if TYPE_CHECKING or _transformers_available:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
CONFIG_MAPPING = None
|
CONFIG_MAPPING = None
|
||||||
DynamicCache = None
|
|
||||||
modeling_gemma = None
|
modeling_gemma = None
|
||||||
PiGemmaForCausalLM = None
|
PiGemmaForCausalLM = None
|
||||||
_gated_residual = None
|
_gated_residual = None
|
||||||
@@ -142,15 +141,6 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
|||||||
return att_2d_masks & pad_2d_masks
|
return att_2d_masks & pad_2d_masks
|
||||||
|
|
||||||
|
|
||||||
def clone_past_key_values(past_key_values):
|
|
||||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
|
||||||
return DynamicCache(
|
|
||||||
tuple(
|
|
||||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pad_vector(vector, new_dim):
|
def pad_vector(vector, new_dim):
|
||||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||||
|
|
||||||
@@ -237,13 +227,16 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
|
|
||||||
|
|
||||||
# Define the complete layer computation function for gradient checkpointing
|
# Define the complete layer computation function for gradient checkpointing
|
||||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
def compute_layer_complete(
|
||||||
|
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||||
|
):
|
||||||
|
models = [paligemma.model.language_model, gemma_expert.model]
|
||||||
query_states = []
|
query_states = []
|
||||||
key_states = []
|
key_states = []
|
||||||
value_states = []
|
value_states = []
|
||||||
gates = []
|
gates = []
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
layer = layers[i]
|
layer = models[i].layers[layer_idx]
|
||||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||||
gates.append(gate)
|
gates.append(gate)
|
||||||
input_shape = hidden_states.shape[:-1]
|
input_shape = hidden_states.shape[:-1]
|
||||||
@@ -265,16 +258,15 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
|
|||||||
device=query_states.device,
|
device=query_states.device,
|
||||||
dtype=query_states.dtype,
|
dtype=query_states.dtype,
|
||||||
)
|
)
|
||||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||||
)
|
)
|
||||||
batch_size = query_states.shape[0]
|
batch_size = query_states.shape[0]
|
||||||
paligemma_layer = layers[0]
|
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||||
scaling = paligemma_layer.self_attn.scaling
|
|
||||||
# Attention computation
|
# Attention computation
|
||||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||||
paligemma_layer.self_attn,
|
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
@@ -282,13 +274,13 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
|
|||||||
scaling,
|
scaling,
|
||||||
)
|
)
|
||||||
# Get head_dim from the current layer, not from the model
|
# Get head_dim from the current layer, not from the model
|
||||||
head_dim = paligemma_layer.self_attn.head_dim
|
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||||
# Process layer outputs
|
# Process layer outputs
|
||||||
outputs_embeds = []
|
outputs_embeds = []
|
||||||
start_pos = 0
|
start_pos = 0
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
layer = layers[i]
|
layer = models[i].layers[layer_idx]
|
||||||
end_pos = start_pos + hidden_states.shape[1]
|
end_pos = start_pos + hidden_states.shape[1]
|
||||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||||
@@ -496,9 +488,8 @@ class PaliGemmaWithExpertModel(
|
|||||||
prefix_output = None
|
prefix_output = None
|
||||||
prefix_past_key_values = None
|
prefix_past_key_values = None
|
||||||
else:
|
else:
|
||||||
paligemma_layers = self.paligemma.model.language_model.layers
|
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||||
gemma_expert_layers = self.gemma_expert.model.layers
|
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
|
||||||
|
|
||||||
# Check if gradient checkpointing is enabled for any of the models
|
# Check if gradient checkpointing is enabled for any of the models
|
||||||
use_gradient_checkpointing = (
|
use_gradient_checkpointing = (
|
||||||
@@ -508,39 +499,36 @@ class PaliGemmaWithExpertModel(
|
|||||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||||
|
|
||||||
# Process all layers with gradient checkpointing if enabled
|
# Process all layers with gradient checkpointing if enabled
|
||||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
for layer_idx in range(num_layers):
|
||||||
if use_gradient_checkpointing:
|
if use_gradient_checkpointing:
|
||||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||||
compute_layer_complete,
|
compute_layer_complete,
|
||||||
|
layer_idx,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
adarms_cond,
|
adarms_cond,
|
||||||
use_reentrant=False,
|
use_reentrant=False,
|
||||||
preserve_rng_state=False,
|
preserve_rng_state=False,
|
||||||
layers=layers,
|
paligemma=self.paligemma,
|
||||||
rotary_emb=rotary_emb,
|
gemma_expert=self.gemma_expert,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
inputs_embeds = compute_layer_complete(
|
inputs_embeds = compute_layer_complete(
|
||||||
|
layer_idx,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
adarms_cond,
|
adarms_cond,
|
||||||
layers=layers,
|
paligemma=self.paligemma,
|
||||||
rotary_emb=rotary_emb,
|
gemma_expert=self.gemma_expert,
|
||||||
)
|
)
|
||||||
|
|
||||||
# final norm
|
# final norm
|
||||||
final_norms = (
|
|
||||||
self.paligemma.model.language_model.norm,
|
|
||||||
self.gemma_expert.model.norm,
|
|
||||||
)
|
|
||||||
|
|
||||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||||
outputs_embeds = []
|
outputs_embeds = []
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||||
outputs_embeds.append(out_emb)
|
outputs_embeds.append(out_emb)
|
||||||
return outputs_embeds
|
return outputs_embeds
|
||||||
|
|
||||||
@@ -919,7 +907,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||||
|
|
||||||
past_key_values = clone_past_key_values(past_key_values)
|
past_key_values = copy.deepcopy(past_key_values)
|
||||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||||
attention_mask=full_att_2d_masks_4d,
|
attention_mask=full_att_2d_masks_4d,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
@@ -29,7 +30,6 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
|||||||
|
|
||||||
# Conditional import for type checking and lazy loading
|
# Conditional import for type checking and lazy loading
|
||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
from transformers.cache_utils import DynamicCache
|
|
||||||
from transformers.models.auto import CONFIG_MAPPING
|
from transformers.models.auto import CONFIG_MAPPING
|
||||||
from transformers.models.gemma import modeling_gemma
|
from transformers.models.gemma import modeling_gemma
|
||||||
|
|
||||||
@@ -41,7 +41,6 @@ if TYPE_CHECKING or _transformers_available:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
CONFIG_MAPPING = None
|
CONFIG_MAPPING = None
|
||||||
DynamicCache = None
|
|
||||||
modeling_gemma = None
|
modeling_gemma = None
|
||||||
PiGemmaForCausalLM = None
|
PiGemmaForCausalLM = None
|
||||||
_gated_residual = None
|
_gated_residual = None
|
||||||
@@ -139,15 +138,6 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
|||||||
return att_2d_masks & pad_2d_masks
|
return att_2d_masks & pad_2d_masks
|
||||||
|
|
||||||
|
|
||||||
def clone_past_key_values(past_key_values):
|
|
||||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
|
||||||
return DynamicCache(
|
|
||||||
tuple(
|
|
||||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pad_vector(vector, new_dim):
|
def pad_vector(vector, new_dim):
|
||||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||||
|
|
||||||
@@ -234,13 +224,16 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
|||||||
|
|
||||||
|
|
||||||
# Define the complete layer computation function for gradient checkpointing
|
# Define the complete layer computation function for gradient checkpointing
|
||||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
def compute_layer_complete(
|
||||||
|
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||||
|
):
|
||||||
|
models = [paligemma.model.language_model, gemma_expert.model]
|
||||||
query_states = []
|
query_states = []
|
||||||
key_states = []
|
key_states = []
|
||||||
value_states = []
|
value_states = []
|
||||||
gates = []
|
gates = []
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
layer = layers[i]
|
layer = models[i].layers[layer_idx]
|
||||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||||
gates.append(gate)
|
gates.append(gate)
|
||||||
input_shape = hidden_states.shape[:-1]
|
input_shape = hidden_states.shape[:-1]
|
||||||
@@ -262,16 +255,15 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
|
|||||||
device=query_states.device,
|
device=query_states.device,
|
||||||
dtype=query_states.dtype,
|
dtype=query_states.dtype,
|
||||||
)
|
)
|
||||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||||
)
|
)
|
||||||
batch_size = query_states.shape[0]
|
batch_size = query_states.shape[0]
|
||||||
paligemma_layer = layers[0]
|
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||||
scaling = paligemma_layer.self_attn.scaling
|
|
||||||
# Attention computation
|
# Attention computation
|
||||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||||
paligemma_layer.self_attn,
|
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
@@ -279,13 +271,13 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
|
|||||||
scaling,
|
scaling,
|
||||||
)
|
)
|
||||||
# Get head_dim from the current layer, not from the model
|
# Get head_dim from the current layer, not from the model
|
||||||
head_dim = paligemma_layer.self_attn.head_dim
|
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||||
# Process layer outputs
|
# Process layer outputs
|
||||||
outputs_embeds = []
|
outputs_embeds = []
|
||||||
start_pos = 0
|
start_pos = 0
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
layer = layers[i]
|
layer = models[i].layers[layer_idx]
|
||||||
end_pos = start_pos + hidden_states.shape[1]
|
end_pos = start_pos + hidden_states.shape[1]
|
||||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||||
@@ -493,9 +485,8 @@ class PaliGemmaWithExpertModel(
|
|||||||
prefix_output = None
|
prefix_output = None
|
||||||
prefix_past_key_values = None
|
prefix_past_key_values = None
|
||||||
else:
|
else:
|
||||||
paligemma_layers = self.paligemma.model.language_model.layers
|
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||||
gemma_expert_layers = self.gemma_expert.model.layers
|
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
|
||||||
|
|
||||||
# Check if gradient checkpointing is enabled for any of the models
|
# Check if gradient checkpointing is enabled for any of the models
|
||||||
use_gradient_checkpointing = (
|
use_gradient_checkpointing = (
|
||||||
@@ -505,39 +496,36 @@ class PaliGemmaWithExpertModel(
|
|||||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||||
|
|
||||||
# Process all layers with gradient checkpointing if enabled
|
# Process all layers with gradient checkpointing if enabled
|
||||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
for layer_idx in range(num_layers):
|
||||||
if use_gradient_checkpointing:
|
if use_gradient_checkpointing:
|
||||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||||
compute_layer_complete,
|
compute_layer_complete,
|
||||||
|
layer_idx,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
adarms_cond,
|
adarms_cond,
|
||||||
use_reentrant=False,
|
use_reentrant=False,
|
||||||
preserve_rng_state=False,
|
preserve_rng_state=False,
|
||||||
layers=layers,
|
paligemma=self.paligemma,
|
||||||
rotary_emb=rotary_emb,
|
gemma_expert=self.gemma_expert,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
inputs_embeds = compute_layer_complete(
|
inputs_embeds = compute_layer_complete(
|
||||||
|
layer_idx,
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
adarms_cond,
|
adarms_cond,
|
||||||
layers=layers,
|
paligemma=self.paligemma,
|
||||||
rotary_emb=rotary_emb,
|
gemma_expert=self.gemma_expert,
|
||||||
)
|
)
|
||||||
|
|
||||||
# final norm
|
# final norm
|
||||||
final_norms = (
|
|
||||||
self.paligemma.model.language_model.norm,
|
|
||||||
self.gemma_expert.model.norm,
|
|
||||||
)
|
|
||||||
|
|
||||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||||
outputs_embeds = []
|
outputs_embeds = []
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
for i, hidden_states in enumerate(inputs_embeds):
|
||||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||||
outputs_embeds.append(out_emb)
|
outputs_embeds.append(out_emb)
|
||||||
return outputs_embeds
|
return outputs_embeds
|
||||||
|
|
||||||
@@ -892,7 +880,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||||
|
|
||||||
past_key_values = clone_past_key_values(past_key_values)
|
past_key_values = copy.deepcopy(past_key_values)
|
||||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||||
attention_mask=full_att_2d_masks_4d,
|
attention_mask=full_att_2d_masks_4d,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
|||||||
@@ -248,7 +248,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
def generate_model_card(
|
def generate_model_card(
|
||||||
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
||||||
) -> ModelCard:
|
) -> ModelCard:
|
||||||
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model
|
base_model_mapping = {
|
||||||
|
"smolvla": "lerobot/smolvla_base",
|
||||||
|
"pi0": "lerobot/pi0_base",
|
||||||
|
"pi05": "lerobot/pi05_base",
|
||||||
|
"pi0_fast": "lerobot/pi0fast-base",
|
||||||
|
"xvla": "lerobot/xvla-base",
|
||||||
|
}
|
||||||
|
|
||||||
card_data = ModelCardData(
|
card_data = ModelCardData(
|
||||||
license=license or "apache-2.0",
|
license=license or "apache-2.0",
|
||||||
@@ -257,7 +263,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
|
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
|
||||||
model_name=model_type,
|
model_name=model_type,
|
||||||
datasets=dataset_repo_id,
|
datasets=dataset_repo_id,
|
||||||
base_model=base_model,
|
base_model=base_model_mapping(model_type, None),
|
||||||
)
|
)
|
||||||
|
|
||||||
template_card = (
|
template_card = (
|
||||||
|
|||||||
@@ -73,14 +73,17 @@ _Writes checkpoints to `outputs/train/<desired_policy_repo_id>/checkpoints/`._
|
|||||||
### Evaluate the policy/run inference
|
### Evaluate the policy/run inference
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-record \
|
lerobot-rollout \
|
||||||
--robot.type=so100_follower \
|
--strategy.type=base \
|
||||||
--dataset.repo_id=<hf_user>/eval_<dataset> \
|
--robot.type=so101_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video1, width: 640, height: 480, fps: 30}, side: {type: opencv, index_or_path: /dev/video5, width: 640, height: 480, fps: 30}}" \
|
||||||
--policy.path=<hf_user>/<desired_policy_repo_id> \
|
--policy.path=<hf_user>/<desired_policy_repo_id> \
|
||||||
--episodes=10
|
--task="Put lego brick into the transparent box" \
|
||||||
|
--duration=60
|
||||||
```
|
```
|
||||||
|
|
||||||
Prefix the dataset repo with **eval\_** and supply `--policy.path` pointing to a local or hub checkpoint.
|
If you want to record a dataset while testing the policy use `--dataset.repo_id=<hf_user>/eval_dataset_name` it is important to use the prefix **eval\_**. For the policy path use the policy from the Hugging Face Hub or a local one. Skipping duration will make the policy run indefinitely.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
"""Lightweight vendored OpenPI PyTorch modules for PI0/PI05 parity tests."""
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Config:
|
|
||||||
width: int
|
|
||||||
depth: int
|
|
||||||
mlp_dim: int
|
|
||||||
num_heads: int
|
|
||||||
num_kv_heads: int
|
|
||||||
head_dim: int
|
|
||||||
|
|
||||||
|
|
||||||
def get_config(variant: str) -> Config:
|
|
||||||
"""Return the Gemma shape config needed by the OpenPI PyTorch model."""
|
|
||||||
if variant == "dummy":
|
|
||||||
return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16)
|
|
||||||
if variant == "gemma_300m":
|
|
||||||
return Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256)
|
|
||||||
if variant == "gemma_2b":
|
|
||||||
return Config(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256)
|
|
||||||
raise ValueError(f"Unknown variant: {variant}")
|
|
||||||
@@ -1,300 +0,0 @@
|
|||||||
from typing import Literal
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from transformers.models.auto import CONFIG_MAPPING
|
|
||||||
from transformers.models.gemma import modeling_gemma
|
|
||||||
|
|
||||||
from lerobot.policies.pi_gemma import (
|
|
||||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
|
||||||
PiGemmaForCausalLM,
|
|
||||||
_gated_residual,
|
|
||||||
layernorm_forward,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PaliGemmaWithExpertModel(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vlm_config,
|
|
||||||
action_expert_config,
|
|
||||||
use_adarms=None,
|
|
||||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
|
||||||
):
|
|
||||||
if use_adarms is None:
|
|
||||||
use_adarms = [False, False]
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
|
||||||
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
|
||||||
vlm_config_hf.image_token_index = 257152
|
|
||||||
vlm_config_hf.text_config.hidden_size = vlm_config.width
|
|
||||||
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
|
|
||||||
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
|
|
||||||
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
|
|
||||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
|
||||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
|
||||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
|
||||||
vlm_config_hf.text_config.dtype = "float32"
|
|
||||||
vlm_config_hf.text_config.vocab_size = 257152
|
|
||||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
|
||||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
|
||||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
|
||||||
vlm_config_hf.vision_config.projection_dim = 2048
|
|
||||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
|
||||||
vlm_config_hf.vision_config.dtype = "float32"
|
|
||||||
|
|
||||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
|
||||||
head_dim=action_expert_config.head_dim,
|
|
||||||
hidden_size=action_expert_config.width,
|
|
||||||
intermediate_size=action_expert_config.mlp_dim,
|
|
||||||
num_attention_heads=action_expert_config.num_heads,
|
|
||||||
num_hidden_layers=action_expert_config.depth,
|
|
||||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
|
||||||
vocab_size=257152,
|
|
||||||
hidden_activation="gelu_pytorch_tanh",
|
|
||||||
dtype="float32",
|
|
||||||
use_adarms=use_adarms[1],
|
|
||||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
|
||||||
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
|
||||||
self.gemma_expert.model.embed_tokens = None
|
|
||||||
|
|
||||||
self.to_bfloat16_for_selected_params(precision)
|
|
||||||
|
|
||||||
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
|
||||||
if precision == "bfloat16":
|
|
||||||
self.to(dtype=torch.bfloat16)
|
|
||||||
elif precision == "float32":
|
|
||||||
self.to(dtype=torch.float32)
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid precision: {precision}")
|
|
||||||
|
|
||||||
params_to_keep_float32 = [
|
|
||||||
"vision_tower",
|
|
||||||
"multi_modal_projector",
|
|
||||||
"input_layernorm",
|
|
||||||
"post_attention_layernorm",
|
|
||||||
"model.norm",
|
|
||||||
]
|
|
||||||
|
|
||||||
for name, param in self.named_parameters():
|
|
||||||
if any(selector in name for selector in params_to_keep_float32):
|
|
||||||
param.data = param.data.to(dtype=torch.float32)
|
|
||||||
|
|
||||||
def embed_image(self, image: torch.Tensor):
|
|
||||||
# Transformers 5.4 no longer divides PaliGemma image features by sqrt(hidden_size),
|
|
||||||
# so the upstream helper now matches OpenPI's patched PaliGemma image-scale semantics.
|
|
||||||
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-c916907e7e52ac85ee1a1527560eae4656cd6c76141ceb1fe3da61bd5f697d2a
|
|
||||||
out_dtype = image.dtype
|
|
||||||
if image.dtype != torch.float32:
|
|
||||||
image = image.to(torch.float32)
|
|
||||||
image_outputs = self.paligemma.model.get_image_features(image)
|
|
||||||
features = image_outputs.pooler_output
|
|
||||||
if features.dtype != out_dtype:
|
|
||||||
features = features.to(out_dtype)
|
|
||||||
return features
|
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
|
||||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
attention_mask: torch.Tensor | None = None,
|
|
||||||
position_ids: torch.LongTensor | None = None,
|
|
||||||
past_key_values: list[torch.FloatTensor] | None = None,
|
|
||||||
inputs_embeds: list[torch.FloatTensor] | None = None,
|
|
||||||
use_cache: bool | None = None,
|
|
||||||
adarms_cond: list[torch.Tensor] | None = None,
|
|
||||||
):
|
|
||||||
if adarms_cond is None:
|
|
||||||
adarms_cond = [None, None]
|
|
||||||
if inputs_embeds[1] is None:
|
|
||||||
prefix_output = self.paligemma.model.language_model.forward(
|
|
||||||
inputs_embeds=inputs_embeds[0],
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
use_cache=use_cache,
|
|
||||||
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
|
|
||||||
)
|
|
||||||
prefix_past_key_values = prefix_output.past_key_values
|
|
||||||
prefix_output = prefix_output.last_hidden_state
|
|
||||||
suffix_output = None
|
|
||||||
elif inputs_embeds[0] is None:
|
|
||||||
suffix_output = self.gemma_expert.model.forward(
|
|
||||||
inputs_embeds=inputs_embeds[1],
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
use_cache=use_cache,
|
|
||||||
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
|
|
||||||
)
|
|
||||||
suffix_output = suffix_output.last_hidden_state
|
|
||||||
prefix_output = None
|
|
||||||
prefix_past_key_values = None
|
|
||||||
else:
|
|
||||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
|
||||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
|
||||||
|
|
||||||
# Check if gradient checkpointing is enabled for any of the models
|
|
||||||
use_gradient_checkpointing = (
|
|
||||||
hasattr(self.gemma_expert.model, "gradient_checkpointing")
|
|
||||||
and self.gemma_expert.model.gradient_checkpointing
|
|
||||||
and self.training
|
|
||||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
|
||||||
|
|
||||||
# Force enable gradient checkpointing if we're in training mode and the model supports it
|
|
||||||
if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
|
|
||||||
if not self.gemma_expert.model.gradient_checkpointing:
|
|
||||||
print("Forcing gradient checkpointing to be enabled for Gemma expert model")
|
|
||||||
self.gemma_expert.model.gradient_checkpointing = True
|
|
||||||
use_gradient_checkpointing = True
|
|
||||||
|
|
||||||
# Debug gradient checkpointing status
|
|
||||||
if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
|
|
||||||
print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
|
|
||||||
print(f"Model training mode: {self.training}")
|
|
||||||
print(
|
|
||||||
f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
|
|
||||||
)
|
|
||||||
if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
|
|
||||||
print(
|
|
||||||
f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
|
|
||||||
)
|
|
||||||
self._debug_gc_printed = True
|
|
||||||
|
|
||||||
# Define the complete layer computation function for gradient checkpointing
|
|
||||||
def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
|
|
||||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
|
||||||
|
|
||||||
query_states = []
|
|
||||||
key_states = []
|
|
||||||
value_states = []
|
|
||||||
gates = []
|
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
|
||||||
layer = models[i].layers[layer_idx]
|
|
||||||
hidden_states, gate = layernorm_forward(
|
|
||||||
layer.input_layernorm, hidden_states, adarms_cond[i]
|
|
||||||
)
|
|
||||||
gates.append(gate)
|
|
||||||
|
|
||||||
input_shape = hidden_states.shape[:-1]
|
|
||||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
|
||||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
||||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
||||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
||||||
|
|
||||||
query_states.append(query_state)
|
|
||||||
key_states.append(key_state)
|
|
||||||
value_states.append(value_state)
|
|
||||||
|
|
||||||
# Concatenate and process attention
|
|
||||||
query_states = torch.cat(query_states, dim=2)
|
|
||||||
key_states = torch.cat(key_states, dim=2)
|
|
||||||
value_states = torch.cat(value_states, dim=2)
|
|
||||||
|
|
||||||
dummy_tensor = torch.zeros(
|
|
||||||
query_states.shape[0],
|
|
||||||
query_states.shape[2],
|
|
||||||
query_states.shape[-1],
|
|
||||||
device=query_states.device,
|
|
||||||
dtype=query_states.dtype,
|
|
||||||
)
|
|
||||||
cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
|
||||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
|
||||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_size = query_states.shape[0]
|
|
||||||
scaling = self.paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
|
||||||
|
|
||||||
# Attention computation
|
|
||||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
|
||||||
self.paligemma.model.language_model.layers[layer_idx].self_attn,
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
attention_mask,
|
|
||||||
scaling,
|
|
||||||
)
|
|
||||||
# Get head_dim from the current layer, not from the model
|
|
||||||
head_dim = self.paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
|
||||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
|
||||||
|
|
||||||
# Process layer outputs
|
|
||||||
outputs_embeds = []
|
|
||||||
start_pos = 0
|
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
|
||||||
layer = models[i].layers[layer_idx]
|
|
||||||
end_pos = start_pos + hidden_states.shape[1]
|
|
||||||
|
|
||||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
|
||||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
|
||||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
|
||||||
|
|
||||||
# first residual
|
|
||||||
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
|
||||||
after_first_residual = out_emb.clone()
|
|
||||||
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
|
||||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
|
||||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
|
||||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
out_emb = layer.mlp(out_emb)
|
|
||||||
# second residual
|
|
||||||
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
|
||||||
outputs_embeds.append(out_emb)
|
|
||||||
start_pos = end_pos
|
|
||||||
|
|
||||||
return outputs_embeds
|
|
||||||
|
|
||||||
# Process all layers with gradient checkpointing if enabled
|
|
||||||
for layer_idx in range(num_layers):
|
|
||||||
if use_gradient_checkpointing:
|
|
||||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
|
||||||
compute_layer_complete,
|
|
||||||
layer_idx,
|
|
||||||
inputs_embeds,
|
|
||||||
attention_mask,
|
|
||||||
position_ids,
|
|
||||||
adarms_cond,
|
|
||||||
use_reentrant=False,
|
|
||||||
preserve_rng_state=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
inputs_embeds = compute_layer_complete(
|
|
||||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
|
|
||||||
)
|
|
||||||
|
|
||||||
# Old code removed - now using compute_layer_complete function above
|
|
||||||
|
|
||||||
# final norm
|
|
||||||
# Define final norm computation function for gradient checkpointing
|
|
||||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
|
||||||
outputs_embeds = []
|
|
||||||
for i, hidden_states in enumerate(inputs_embeds):
|
|
||||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
|
||||||
outputs_embeds.append(out_emb)
|
|
||||||
return outputs_embeds
|
|
||||||
|
|
||||||
# Apply gradient checkpointing to final norm if enabled
|
|
||||||
if use_gradient_checkpointing:
|
|
||||||
outputs_embeds = torch.utils.checkpoint.checkpoint(
|
|
||||||
compute_final_norms,
|
|
||||||
inputs_embeds,
|
|
||||||
adarms_cond,
|
|
||||||
use_reentrant=False,
|
|
||||||
preserve_rng_state=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
|
|
||||||
|
|
||||||
prefix_output = outputs_embeds[0]
|
|
||||||
suffix_output = outputs_embeds[1]
|
|
||||||
prefix_past_key_values = None
|
|
||||||
|
|
||||||
return [prefix_output, suffix_output], prefix_past_key_values
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn.functional as F # noqa: N812
|
|
||||||
|
|
||||||
|
|
||||||
def resize_with_pad_torch(
|
|
||||||
images: torch.Tensor,
|
|
||||||
height: int,
|
|
||||||
width: int,
|
|
||||||
mode: str = "bilinear",
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
|
|
||||||
by padding with black. If the image is float32, it must be in the range [-1, 1].
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
|
|
||||||
height: Target height
|
|
||||||
width: Target width
|
|
||||||
mode: Interpolation mode ('bilinear', 'nearest', etc.)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Resized and padded tensor with same shape format as input
|
|
||||||
"""
|
|
||||||
# Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
|
|
||||||
if images.shape[-1] <= 4: # Assume channels-last format
|
|
||||||
channels_last = True
|
|
||||||
# Convert to channels-first for torch operations
|
|
||||||
if images.dim() == 3:
|
|
||||||
images = images.unsqueeze(0) # Add batch dimension
|
|
||||||
images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w]
|
|
||||||
else:
|
|
||||||
channels_last = False
|
|
||||||
if images.dim() == 3:
|
|
||||||
images = images.unsqueeze(0) # Add batch dimension
|
|
||||||
|
|
||||||
batch_size, channels, cur_height, cur_width = images.shape
|
|
||||||
|
|
||||||
# Calculate resize ratio
|
|
||||||
ratio = max(cur_width / width, cur_height / height)
|
|
||||||
resized_height = int(cur_height / ratio)
|
|
||||||
resized_width = int(cur_width / ratio)
|
|
||||||
|
|
||||||
# Resize
|
|
||||||
resized_images = F.interpolate(
|
|
||||||
images,
|
|
||||||
size=(resized_height, resized_width),
|
|
||||||
mode=mode,
|
|
||||||
align_corners=False if mode == "bilinear" else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle dtype-specific clipping
|
|
||||||
if images.dtype == torch.uint8:
|
|
||||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
|
||||||
elif images.dtype == torch.float32:
|
|
||||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
|
||||||
|
|
||||||
# Calculate padding
|
|
||||||
pad_h0, remainder_h = divmod(height - resized_height, 2)
|
|
||||||
pad_h1 = pad_h0 + remainder_h
|
|
||||||
pad_w0, remainder_w = divmod(width - resized_width, 2)
|
|
||||||
pad_w1 = pad_w0 + remainder_w
|
|
||||||
|
|
||||||
# Pad
|
|
||||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
|
||||||
padded_images = F.pad(
|
|
||||||
resized_images,
|
|
||||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
|
||||||
mode="constant",
|
|
||||||
value=constant_value,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert back to original format if needed
|
|
||||||
if channels_last:
|
|
||||||
padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
|
||||||
if batch_size == 1 and images.shape[0] == 1:
|
|
||||||
padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
|
|
||||||
|
|
||||||
return padded_images
|
|
||||||
@@ -1,471 +0,0 @@
|
|||||||
import copy
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F # noqa: N812
|
|
||||||
from torch import Tensor, nn
|
|
||||||
|
|
||||||
import tests.policies.pi0_pi05.openpi_pytorch.gemma as _gemma
|
|
||||||
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as _preprocessing
|
|
||||||
from tests.policies.pi0_pi05.openpi_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
|
|
||||||
|
|
||||||
|
|
||||||
def get_safe_dtype(target_dtype, device_type):
|
|
||||||
"""Get a safe dtype for the given device type."""
|
|
||||||
if device_type == "cpu":
|
|
||||||
# CPU doesn't support bfloat16, use float32 instead
|
|
||||||
if target_dtype == torch.bfloat16:
|
|
||||||
return torch.float32
|
|
||||||
if target_dtype == torch.float64:
|
|
||||||
return torch.float64
|
|
||||||
return target_dtype
|
|
||||||
|
|
||||||
|
|
||||||
def create_sinusoidal_pos_embedding(
|
|
||||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
|
||||||
) -> Tensor:
|
|
||||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
|
||||||
if dimension % 2 != 0:
|
|
||||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
|
||||||
|
|
||||||
if time.ndim != 1:
|
|
||||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
|
||||||
|
|
||||||
dtype = get_safe_dtype(torch.float64, device.type)
|
|
||||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
|
||||||
period = min_period * (max_period / min_period) ** fraction
|
|
||||||
|
|
||||||
# Compute the outer product
|
|
||||||
scaling_factor = 1.0 / period * 2 * math.pi
|
|
||||||
sin_input = scaling_factor[None, :] * time[:, None]
|
|
||||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
def sample_beta(alpha, beta, bsize, device):
|
|
||||||
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
|
||||||
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
|
||||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
|
||||||
return dist.sample((bsize,))
|
|
||||||
|
|
||||||
|
|
||||||
def make_att_2d_masks(pad_masks, att_masks):
|
|
||||||
"""Copied from big_vision.
|
|
||||||
|
|
||||||
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
|
||||||
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
|
||||||
setup several types of attention, for example:
|
|
||||||
|
|
||||||
[[1 1 1 1 1 1]]: pure causal attention.
|
|
||||||
|
|
||||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
|
||||||
themselves and the last 3 tokens have a causal attention. The first
|
|
||||||
entry could also be a 1 without changing behaviour.
|
|
||||||
|
|
||||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
|
||||||
block can attend all previous blocks and all tokens on the same block.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_mask: bool[B, N] true if its part of the input, false if padding.
|
|
||||||
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
|
||||||
it and 0 where it shares the same attention mask as the previous token.
|
|
||||||
"""
|
|
||||||
if att_masks.ndim != 2:
|
|
||||||
raise ValueError(att_masks.ndim)
|
|
||||||
if pad_masks.ndim != 2:
|
|
||||||
raise ValueError(pad_masks.ndim)
|
|
||||||
|
|
||||||
cumsum = torch.cumsum(att_masks, dim=1)
|
|
||||||
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
|
||||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
|
||||||
return att_2d_masks & pad_2d_masks
|
|
||||||
|
|
||||||
|
|
||||||
class PI0Pytorch(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.pi05 = config.pi05
|
|
||||||
|
|
||||||
paligemma_config = _gemma.get_config(config.paligemma_variant)
|
|
||||||
action_expert_config = _gemma.get_config(config.action_expert_variant)
|
|
||||||
|
|
||||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
|
||||||
paligemma_config,
|
|
||||||
action_expert_config,
|
|
||||||
use_adarms=[False, True] if self.pi05 else [False, False],
|
|
||||||
precision=config.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)
|
|
||||||
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)
|
|
||||||
|
|
||||||
if self.pi05:
|
|
||||||
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
|
|
||||||
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
|
||||||
else:
|
|
||||||
self.state_proj = nn.Linear(config.action_dim, action_expert_config.width)
|
|
||||||
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
|
|
||||||
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
|
||||||
|
|
||||||
torch.set_float32_matmul_precision("high")
|
|
||||||
if config.pytorch_compile_mode is not None:
|
|
||||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode)
|
|
||||||
|
|
||||||
# Initialize gradient checkpointing flag
|
|
||||||
self.gradient_checkpointing_enabled = False
|
|
||||||
|
|
||||||
# The upstream OpenPI module verifies a site-package Transformers patch here.
|
|
||||||
# This vendored test copy instead routes through LeRobot's local PiGemma compatibility layer.
|
|
||||||
|
|
||||||
def gradient_checkpointing_enable(self):
|
|
||||||
"""Enable gradient checkpointing for memory optimization."""
|
|
||||||
self.gradient_checkpointing_enabled = True
|
|
||||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
|
||||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
|
||||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
|
||||||
|
|
||||||
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
|
|
||||||
|
|
||||||
def gradient_checkpointing_disable(self):
|
|
||||||
"""Disable gradient checkpointing."""
|
|
||||||
self.gradient_checkpointing_enabled = False
|
|
||||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
|
||||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
|
||||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
|
||||||
|
|
||||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
|
||||||
|
|
||||||
def is_gradient_checkpointing_enabled(self):
|
|
||||||
"""Check if gradient checkpointing is enabled."""
|
|
||||||
return self.gradient_checkpointing_enabled
|
|
||||||
|
|
||||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
|
||||||
"""Helper method to apply gradient checkpointing if enabled."""
|
|
||||||
if self.gradient_checkpointing_enabled and self.training:
|
|
||||||
return torch.utils.checkpoint.checkpoint(
|
|
||||||
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
|
||||||
)
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
def _prepare_attention_masks_4d(self, att_2d_masks):
|
|
||||||
"""Helper method to prepare 4D attention masks for transformer."""
|
|
||||||
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
|
||||||
return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
|
|
||||||
|
|
||||||
def _preprocess_observation(self, observation, *, train=True):
|
|
||||||
"""Helper method to preprocess observation."""
|
|
||||||
observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
|
|
||||||
return (
|
|
||||||
list(observation.images.values()),
|
|
||||||
list(observation.image_masks.values()),
|
|
||||||
observation.tokenized_prompt,
|
|
||||||
observation.tokenized_prompt_mask,
|
|
||||||
observation.state,
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_noise(self, shape, device):
|
|
||||||
return torch.normal(
|
|
||||||
mean=0.0,
|
|
||||||
std=1.0,
|
|
||||||
size=shape,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
def sample_time(self, bsize, device):
|
|
||||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
|
||||||
time = time_beta * 0.999 + 0.001
|
|
||||||
return time.to(dtype=torch.float32, device=device)
|
|
||||||
|
|
||||||
def embed_prefix(
|
|
||||||
self, images, img_masks, lang_tokens, lang_masks
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
||||||
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
|
||||||
for PaliGemma transformer processing.
|
|
||||||
"""
|
|
||||||
embs = []
|
|
||||||
pad_masks = []
|
|
||||||
att_masks = []
|
|
||||||
|
|
||||||
# Process images
|
|
||||||
for img, img_mask in zip(images, img_masks, strict=True):
|
|
||||||
|
|
||||||
def image_embed_func(img):
|
|
||||||
return self.paligemma_with_expert.embed_image(img)
|
|
||||||
|
|
||||||
img_emb = self._apply_checkpoint(image_embed_func, img)
|
|
||||||
|
|
||||||
bsize, num_img_embs = img_emb.shape[:2]
|
|
||||||
|
|
||||||
embs.append(img_emb)
|
|
||||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
|
||||||
|
|
||||||
# Create attention masks so that image tokens attend to each other
|
|
||||||
att_masks += [0] * num_img_embs
|
|
||||||
|
|
||||||
# Process language tokens
|
|
||||||
def lang_embed_func(lang_tokens):
|
|
||||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
|
||||||
# Transformers > 5.4 scales Gemma token embeddings inside embed_tokens, matching
|
|
||||||
# OpenPI's former explicit sqrt(hidden_size) multiply without applying it twice.
|
|
||||||
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834
|
|
||||||
return lang_emb
|
|
||||||
|
|
||||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
|
||||||
|
|
||||||
embs.append(lang_emb)
|
|
||||||
pad_masks.append(lang_masks)
|
|
||||||
|
|
||||||
# full attention between image and language inputs
|
|
||||||
num_lang_embs = lang_emb.shape[1]
|
|
||||||
att_masks += [0] * num_lang_embs
|
|
||||||
|
|
||||||
embs = torch.cat(embs, dim=1)
|
|
||||||
pad_masks = torch.cat(pad_masks, dim=1)
|
|
||||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
|
||||||
|
|
||||||
# Get batch size from the first dimension of the concatenated tensors
|
|
||||||
bsize = pad_masks.shape[0]
|
|
||||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
|
||||||
|
|
||||||
return embs, pad_masks, att_masks
|
|
||||||
|
|
||||||
def embed_suffix(self, state, noisy_actions, timestep):
|
|
||||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
|
||||||
embs = []
|
|
||||||
pad_masks = []
|
|
||||||
att_masks = []
|
|
||||||
|
|
||||||
if not self.pi05:
|
|
||||||
if self.state_proj.weight.dtype == torch.float32:
|
|
||||||
state = state.to(torch.float32)
|
|
||||||
|
|
||||||
# Embed state
|
|
||||||
def state_proj_func(state):
|
|
||||||
return self.state_proj(state)
|
|
||||||
|
|
||||||
state_emb = self._apply_checkpoint(state_proj_func, state)
|
|
||||||
|
|
||||||
embs.append(state_emb[:, None, :])
|
|
||||||
bsize = state_emb.shape[0]
|
|
||||||
device = state_emb.device
|
|
||||||
|
|
||||||
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
|
||||||
pad_masks.append(state_mask)
|
|
||||||
|
|
||||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
|
||||||
att_masks += [1]
|
|
||||||
|
|
||||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
|
||||||
time_emb = create_sinusoidal_pos_embedding(
|
|
||||||
timestep,
|
|
||||||
self.action_in_proj.out_features,
|
|
||||||
min_period=4e-3,
|
|
||||||
max_period=4.0,
|
|
||||||
device=timestep.device,
|
|
||||||
)
|
|
||||||
time_emb = time_emb.type(dtype=timestep.dtype)
|
|
||||||
|
|
||||||
# Fuse timestep + action information using an MLP
|
|
||||||
def action_proj_func(noisy_actions):
|
|
||||||
return self.action_in_proj(noisy_actions)
|
|
||||||
|
|
||||||
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
|
|
||||||
|
|
||||||
if not self.pi05:
|
|
||||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
|
||||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
|
||||||
|
|
||||||
# Apply MLP layers
|
|
||||||
def mlp_func(action_time_emb):
|
|
||||||
x = self.action_time_mlp_in(action_time_emb)
|
|
||||||
x = F.silu(x) # swish == silu
|
|
||||||
return self.action_time_mlp_out(x)
|
|
||||||
|
|
||||||
action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
|
|
||||||
adarms_cond = None
|
|
||||||
else:
|
|
||||||
# time MLP (for adaRMS)
|
|
||||||
def time_mlp_func(time_emb):
|
|
||||||
x = self.time_mlp_in(time_emb)
|
|
||||||
x = F.silu(x) # swish == silu
|
|
||||||
x = self.time_mlp_out(x)
|
|
||||||
return F.silu(x)
|
|
||||||
|
|
||||||
time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
|
|
||||||
action_time_emb = action_emb
|
|
||||||
adarms_cond = time_emb
|
|
||||||
|
|
||||||
# Add to input tokens
|
|
||||||
embs.append(action_time_emb)
|
|
||||||
|
|
||||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
|
||||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
|
|
||||||
pad_masks.append(action_time_mask)
|
|
||||||
|
|
||||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
|
||||||
att_masks += [1] + ([0] * (self.config.action_horizon - 1))
|
|
||||||
|
|
||||||
embs = torch.cat(embs, dim=1)
|
|
||||||
pad_masks = torch.cat(pad_masks, dim=1)
|
|
||||||
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
|
||||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
|
||||||
|
|
||||||
return embs, pad_masks, att_masks, adarms_cond
|
|
||||||
|
|
||||||
def forward(self, observation, actions, noise=None, time=None) -> Tensor:
|
|
||||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
|
||||||
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
|
|
||||||
observation, train=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if noise is None:
|
|
||||||
noise = self.sample_noise(actions.shape, actions.device)
|
|
||||||
|
|
||||||
if time is None:
|
|
||||||
time = self.sample_time(actions.shape[0], actions.device)
|
|
||||||
|
|
||||||
time_expanded = time[:, None, None]
|
|
||||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
|
||||||
u_t = noise - actions
|
|
||||||
|
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
|
||||||
images, img_masks, lang_tokens, lang_masks
|
|
||||||
)
|
|
||||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
|
||||||
if (
|
|
||||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
|
||||||
== torch.bfloat16
|
|
||||||
):
|
|
||||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
|
||||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
|
||||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
|
||||||
|
|
||||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
|
||||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
|
||||||
|
|
||||||
# Prepare attention masks
|
|
||||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
|
|
||||||
|
|
||||||
# Apply gradient checkpointing if enabled
|
|
||||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
|
||||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
|
||||||
attention_mask=att_2d_masks_4d,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=None,
|
|
||||||
inputs_embeds=[prefix_embs, suffix_embs],
|
|
||||||
use_cache=False,
|
|
||||||
adarms_cond=[None, adarms_cond],
|
|
||||||
)
|
|
||||||
return suffix_out
|
|
||||||
|
|
||||||
suffix_out = self._apply_checkpoint(
|
|
||||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
|
||||||
)
|
|
||||||
|
|
||||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
|
||||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
|
||||||
|
|
||||||
# Apply gradient checkpointing to final action projection if enabled
|
|
||||||
def action_out_proj_func(suffix_out):
|
|
||||||
return self.action_out_proj(suffix_out)
|
|
||||||
|
|
||||||
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
|
|
||||||
|
|
||||||
return F.mse_loss(u_t, v_t, reduction="none")
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
|
|
||||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
|
||||||
bsize = observation.state.shape[0]
|
|
||||||
if noise is None:
|
|
||||||
actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
|
|
||||||
noise = self.sample_noise(actions_shape, device)
|
|
||||||
|
|
||||||
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
|
|
||||||
observation, train=False
|
|
||||||
)
|
|
||||||
|
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
|
||||||
images, img_masks, lang_tokens, lang_masks
|
|
||||||
)
|
|
||||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
|
||||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
|
||||||
|
|
||||||
# Compute image and language key value cache
|
|
||||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
|
||||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
|
||||||
|
|
||||||
_, past_key_values = self.paligemma_with_expert.forward(
|
|
||||||
attention_mask=prefix_att_2d_masks_4d,
|
|
||||||
position_ids=prefix_position_ids,
|
|
||||||
past_key_values=None,
|
|
||||||
inputs_embeds=[prefix_embs, None],
|
|
||||||
use_cache=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
dt = -1.0 / num_steps
|
|
||||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
|
||||||
|
|
||||||
x_t = noise
|
|
||||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
|
||||||
while time >= -dt / 2:
|
|
||||||
expanded_time = time.expand(bsize)
|
|
||||||
v_t = self.denoise_step(
|
|
||||||
state,
|
|
||||||
prefix_pad_masks,
|
|
||||||
past_key_values,
|
|
||||||
x_t,
|
|
||||||
expanded_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Euler step - use new tensor assignment instead of in-place operation
|
|
||||||
x_t = x_t + dt * v_t
|
|
||||||
time += dt
|
|
||||||
return x_t
|
|
||||||
|
|
||||||
def denoise_step(
|
|
||||||
self,
|
|
||||||
state,
|
|
||||||
prefix_pad_masks,
|
|
||||||
past_key_values,
|
|
||||||
x_t,
|
|
||||||
timestep,
|
|
||||||
):
|
|
||||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
|
||||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
|
|
||||||
|
|
||||||
suffix_len = suffix_pad_masks.shape[1]
|
|
||||||
batch_size = prefix_pad_masks.shape[0]
|
|
||||||
prefix_len = prefix_pad_masks.shape[1]
|
|
||||||
|
|
||||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
|
||||||
|
|
||||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
|
||||||
|
|
||||||
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
|
||||||
|
|
||||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
|
||||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
|
||||||
|
|
||||||
# Prepare attention masks
|
|
||||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
|
||||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
|
||||||
|
|
||||||
past_key_values = copy.deepcopy(past_key_values)
|
|
||||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
|
||||||
attention_mask=full_att_2d_masks_4d,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=[None, suffix_embs],
|
|
||||||
use_cache=False,
|
|
||||||
adarms_cond=[None, adarms_cond],
|
|
||||||
)
|
|
||||||
|
|
||||||
suffix_out = outputs_embeds[1]
|
|
||||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
|
||||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
|
||||||
return self.action_out_proj(suffix_out)
|
|
||||||
@@ -1,179 +0,0 @@
|
|||||||
import logging
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from tests.policies.pi0_pi05.openpi_pytorch import image_tools
|
|
||||||
|
|
||||||
logger = logging.getLogger("openpi")
|
|
||||||
|
|
||||||
# Constants moved from model.py
|
|
||||||
IMAGE_KEYS = (
|
|
||||||
"base_0_rgb",
|
|
||||||
"left_wrist_0_rgb",
|
|
||||||
"right_wrist_0_rgb",
|
|
||||||
)
|
|
||||||
|
|
||||||
IMAGE_RESOLUTION = (224, 224)
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_observation_pytorch(
|
|
||||||
observation,
|
|
||||||
*,
|
|
||||||
train: bool = False,
|
|
||||||
image_keys: Sequence[str] = IMAGE_KEYS,
|
|
||||||
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
|
|
||||||
):
|
|
||||||
"""Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
|
|
||||||
|
|
||||||
This function avoids complex type annotations that can cause torch.compile issues.
|
|
||||||
"""
|
|
||||||
if not set(image_keys).issubset(observation.images):
|
|
||||||
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
|
|
||||||
|
|
||||||
batch_shape = observation.state.shape[:-1]
|
|
||||||
|
|
||||||
out_images = {}
|
|
||||||
for key in image_keys:
|
|
||||||
image = observation.images[key]
|
|
||||||
|
|
||||||
# TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
|
|
||||||
# Handle both [B, C, H, W] and [B, H, W, C] formats
|
|
||||||
is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1
|
|
||||||
|
|
||||||
if is_channels_first:
|
|
||||||
# Convert [B, C, H, W] to [B, H, W, C] for processing
|
|
||||||
image = image.permute(0, 2, 3, 1)
|
|
||||||
|
|
||||||
if image.shape[1:3] != image_resolution:
|
|
||||||
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
|
|
||||||
image = image_tools.resize_with_pad_torch(image, *image_resolution)
|
|
||||||
|
|
||||||
if train:
|
|
||||||
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
|
|
||||||
image = image / 2.0 + 0.5
|
|
||||||
|
|
||||||
# Apply PyTorch-based augmentations
|
|
||||||
if "wrist" not in key:
|
|
||||||
# Geometric augmentations for non-wrist cameras
|
|
||||||
height, width = image.shape[1:3]
|
|
||||||
|
|
||||||
# Random crop and resize
|
|
||||||
crop_height = int(height * 0.95)
|
|
||||||
crop_width = int(width * 0.95)
|
|
||||||
|
|
||||||
# Random crop
|
|
||||||
max_h = height - crop_height
|
|
||||||
max_w = width - crop_width
|
|
||||||
if max_h > 0 and max_w > 0:
|
|
||||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
|
||||||
start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
|
|
||||||
start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
|
|
||||||
image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
|
|
||||||
|
|
||||||
# Resize back to original size
|
|
||||||
image = torch.nn.functional.interpolate(
|
|
||||||
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
|
||||||
size=(height, width),
|
|
||||||
mode="bilinear",
|
|
||||||
align_corners=False,
|
|
||||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
|
||||||
|
|
||||||
# Random rotation (small angles)
|
|
||||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
|
||||||
angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees
|
|
||||||
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
|
|
||||||
# Convert to radians
|
|
||||||
angle_rad = angle * torch.pi / 180.0
|
|
||||||
|
|
||||||
# Create rotation matrix
|
|
||||||
cos_a = torch.cos(angle_rad)
|
|
||||||
sin_a = torch.sin(angle_rad)
|
|
||||||
|
|
||||||
# Apply rotation using grid_sample
|
|
||||||
grid_x = torch.linspace(-1, 1, width, device=image.device)
|
|
||||||
grid_y = torch.linspace(-1, 1, height, device=image.device)
|
|
||||||
|
|
||||||
# Create meshgrid
|
|
||||||
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
|
|
||||||
|
|
||||||
# Expand to batch dimension
|
|
||||||
grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
|
|
||||||
grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)
|
|
||||||
|
|
||||||
# Apply rotation transformation
|
|
||||||
grid_x_rot = grid_x * cos_a - grid_y * sin_a
|
|
||||||
grid_y_rot = grid_x * sin_a + grid_y * cos_a
|
|
||||||
|
|
||||||
# Stack and reshape for grid_sample
|
|
||||||
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
|
|
||||||
|
|
||||||
image = torch.nn.functional.grid_sample(
|
|
||||||
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
|
||||||
grid,
|
|
||||||
mode="bilinear",
|
|
||||||
padding_mode="zeros",
|
|
||||||
align_corners=False,
|
|
||||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
|
||||||
|
|
||||||
# Color augmentations for all cameras
|
|
||||||
# Random brightness
|
|
||||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
|
||||||
brightness_factor = (
|
|
||||||
0.7 + torch.rand(1, device=image.device) * 0.6
|
|
||||||
) # Random factor between 0.7 and 1.3
|
|
||||||
image = image * brightness_factor
|
|
||||||
|
|
||||||
# Random contrast
|
|
||||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
|
||||||
contrast_factor = (
|
|
||||||
0.6 + torch.rand(1, device=image.device) * 0.8
|
|
||||||
) # Random factor between 0.6 and 1.4
|
|
||||||
mean = image.mean(dim=[1, 2, 3], keepdim=True)
|
|
||||||
image = (image - mean) * contrast_factor + mean
|
|
||||||
|
|
||||||
# Random saturation (convert to HSV, modify S, convert back)
|
|
||||||
# For simplicity, we'll just apply a random scaling to the color channels
|
|
||||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
|
||||||
saturation_factor = (
|
|
||||||
0.5 + torch.rand(1, device=image.device) * 1.0
|
|
||||||
) # Random factor between 0.5 and 1.5
|
|
||||||
gray = image.mean(dim=-1, keepdim=True)
|
|
||||||
image = gray + (image - gray) * saturation_factor
|
|
||||||
|
|
||||||
# Clamp values to [0, 1]
|
|
||||||
image = torch.clamp(image, 0, 1)
|
|
||||||
|
|
||||||
# Back to [-1, 1]
|
|
||||||
image = image * 2.0 - 1.0
|
|
||||||
|
|
||||||
# Convert back to [B, C, H, W] format if it was originally channels-first
|
|
||||||
if is_channels_first:
|
|
||||||
image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
|
|
||||||
|
|
||||||
out_images[key] = image
|
|
||||||
|
|
||||||
# obtain mask
|
|
||||||
out_masks = {}
|
|
||||||
for key in out_images:
|
|
||||||
if key not in observation.image_masks:
|
|
||||||
# do not mask by default
|
|
||||||
out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)
|
|
||||||
else:
|
|
||||||
out_masks[key] = observation.image_masks[key]
|
|
||||||
|
|
||||||
# Create a simple object with the required attributes instead of using the complex Observation class
|
|
||||||
class SimpleProcessedObservation:
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
setattr(self, key, value)
|
|
||||||
|
|
||||||
return SimpleProcessedObservation(
|
|
||||||
images=out_images,
|
|
||||||
image_masks=out_masks,
|
|
||||||
state=observation.state,
|
|
||||||
tokenized_prompt=observation.tokenized_prompt,
|
|
||||||
tokenized_prompt_mask=observation.tokenized_prompt_mask,
|
|
||||||
token_ar_mask=observation.token_ar_mask,
|
|
||||||
token_loss_mask=observation.token_loss_mask,
|
|
||||||
)
|
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
pytest.importorskip("transformers")
|
|
||||||
|
|
||||||
from lerobot.policies.pi05 import PI05Config # noqa: E402
|
|
||||||
from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch # noqa: E402
|
|
||||||
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
|
|
||||||
assert_cache_stability,
|
|
||||||
assert_compiled_output_matches_eager,
|
|
||||||
assert_explain_has_no_graph_breaks,
|
|
||||||
benchmark_runtime,
|
|
||||||
make_compile_config,
|
|
||||||
reset_compile_state,
|
|
||||||
)
|
|
||||||
from tests.utils import require_cuda # noqa: E402
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(
|
|
||||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
|
||||||
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_model(*, compile_model):
|
|
||||||
return PI05Pytorch(make_compile_config(PI05Config, compile_model=compile_model)).cuda().eval()
|
|
||||||
|
|
||||||
|
|
||||||
def _make_dummy_inputs(config):
|
|
||||||
device = torch.device("cuda")
|
|
||||||
common = {
|
|
||||||
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
|
|
||||||
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
|
|
||||||
"tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
|
|
||||||
"masks": torch.ones(1, 5, dtype=torch.bool, device=device),
|
|
||||||
}
|
|
||||||
forward_kwargs = {
|
|
||||||
**common,
|
|
||||||
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
|
||||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
|
||||||
"time": torch.rand(1, device=device),
|
|
||||||
}
|
|
||||||
sample_kwargs = {
|
|
||||||
**common,
|
|
||||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
|
||||||
"num_steps": config.num_inference_steps,
|
|
||||||
}
|
|
||||||
return forward_kwargs, sample_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
@require_cuda
|
|
||||||
def test_pi05_torch_compile_forward_and_sample_actions():
|
|
||||||
if not hasattr(torch, "compile"):
|
|
||||||
pytest.skip("torch.compile is not available")
|
|
||||||
if not torch._dynamo.is_dynamo_supported():
|
|
||||||
pytest.skip("torch._dynamo is not supported on this platform")
|
|
||||||
|
|
||||||
torch.manual_seed(0)
|
|
||||||
eager_model = _make_model(compile_model=False)
|
|
||||||
torch.manual_seed(0)
|
|
||||||
compiled_model = _make_model(compile_model=True)
|
|
||||||
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
|
|
||||||
|
|
||||||
try:
|
|
||||||
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
|
|
||||||
|
|
||||||
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi05.forward")
|
|
||||||
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi05.sample_actions")
|
|
||||||
|
|
||||||
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi05.forward")
|
|
||||||
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi05.sample_actions")
|
|
||||||
|
|
||||||
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi05.forward")
|
|
||||||
benchmark_runtime(
|
|
||||||
eager_model.sample_actions,
|
|
||||||
compiled_model.sample_actions,
|
|
||||||
sample_kwargs,
|
|
||||||
"pi05.sample_actions",
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
reset_compile_state()
|
|
||||||
del eager_model
|
|
||||||
del compiled_model
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
@@ -14,56 +14,52 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Compare LeRobot PI0.5 against the vendored OpenPI PyTorch reference."""
|
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation"""
|
||||||
|
|
||||||
import gc
|
|
||||||
import os
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
# Skip if openpi or transformers is not available
|
||||||
|
pytest.importorskip("openpi")
|
||||||
pytest.importorskip("transformers")
|
pytest.importorskip("transformers")
|
||||||
|
|
||||||
from lerobot.configs import PreTrainedConfig # noqa: E402
|
# Skip this entire module in CI
|
||||||
from lerobot.policies.pi05 import PI05Policy # noqa: E402
|
|
||||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
|
||||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
|
||||||
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
|
||||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
|
||||||
assert_processor_inputs_match_lerobot,
|
|
||||||
clone_batch,
|
|
||||||
deterministic_openpi_forward_preprocess,
|
|
||||||
fix_reference_state_dict,
|
|
||||||
fixed_flow_sampling,
|
|
||||||
load_openpi_reference_state_dict,
|
|
||||||
make_openpi_observation_from_raw,
|
|
||||||
openpi_model_actions_from_raw,
|
|
||||||
)
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(
|
pytestmark = pytest.mark.skipif(
|
||||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||||
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
|
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||||
|
|
||||||
|
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||||
|
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||||
|
from transformers import AutoTokenizer # noqa: E402
|
||||||
|
|
||||||
|
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
|
||||||
|
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||||
|
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
|
||||||
|
from lerobot.types import PolicyAction # noqa: E402
|
||||||
|
|
||||||
|
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||||
DUMMY_ACTION_DIM = 32
|
DUMMY_ACTION_DIM = 32
|
||||||
DUMMY_STATE_DIM = 32
|
DUMMY_STATE_DIM = 32
|
||||||
DUMMY_ACTION_HORIZON = 50
|
DUMMY_ACTION_HORIZON = 50
|
||||||
DUMMY_MAX_TOKEN_LEN = 200
|
DUMMY_MAX_TOKEN_LEN = 200
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||||
COMPILE_MODE = "default"
|
|
||||||
FORWARD_RTOL = 1e-4
|
|
||||||
FORWARD_ATOL = 1e-4
|
|
||||||
SAMPLE_RTOL = 1e-2
|
|
||||||
SAMPLE_ATOL = 5e-3
|
|
||||||
|
|
||||||
DUMMY_DATASET_STATS = {
|
DUMMY_DATASET_STATS = {
|
||||||
OBS_STATE: {
|
"observation.state": {
|
||||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||||
"std": torch.ones(DUMMY_STATE_DIM),
|
"std": torch.ones(DUMMY_STATE_DIM),
|
||||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||||
},
|
},
|
||||||
ACTION: {
|
"action": {
|
||||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||||
@@ -92,15 +88,6 @@ DUMMY_DATASET_STATS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def cleanup_cuda_after_test():
|
|
||||||
yield
|
|
||||||
gc.collect()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.ipc_collect()
|
|
||||||
|
|
||||||
|
|
||||||
class PI05BaseOriginalConfig:
|
class PI05BaseOriginalConfig:
|
||||||
action_dim: int = DUMMY_ACTION_DIM
|
action_dim: int = DUMMY_ACTION_DIM
|
||||||
action_horizon: int = DUMMY_ACTION_HORIZON
|
action_horizon: int = DUMMY_ACTION_HORIZON
|
||||||
@@ -109,163 +96,341 @@ class PI05BaseOriginalConfig:
|
|||||||
precision: str = "float32"
|
precision: str = "float32"
|
||||||
pi05: bool = True
|
pi05: bool = True
|
||||||
dtype: str = "float32"
|
dtype: str = "float32"
|
||||||
pytorch_compile_mode: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def instantiate_lerobot_pi05(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
def instantiate_lerobot_pi05(
|
||||||
config = PreTrainedConfig.from_pretrained("lerobot/pi05_base")
|
from_pretrained: bool = False,
|
||||||
config.device = str(DEVICE)
|
) -> tuple[
|
||||||
config.dtype = "float32"
|
PI05Policy,
|
||||||
config.compile_model = compile_model
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
config.compile_mode = COMPILE_MODE
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
config.gradient_checkpointing = gradient_checkpointing
|
]:
|
||||||
|
if from_pretrained:
|
||||||
|
# Load the policy first
|
||||||
|
policy = PI05Policy.from_pretrained(pretrained_name_or_path="lerobot/pi05_base", strict=True)
|
||||||
|
else:
|
||||||
|
config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||||
|
policy = PI05Policy(config)
|
||||||
|
|
||||||
policy = PI05Policy.from_pretrained("lerobot/pi05_base", config=config, strict=True)
|
|
||||||
policy.to(DEVICE)
|
policy.to(DEVICE)
|
||||||
policy.config.device = str(DEVICE)
|
policy.config.device = DEVICE
|
||||||
preprocessor, _ = make_pi05_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
|
preprocessor, postprocessor = make_pi05_pre_post_processors(
|
||||||
return policy, preprocessor
|
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||||
|
)
|
||||||
|
return (policy, preprocessor, postprocessor)
|
||||||
|
|
||||||
|
|
||||||
def instantiate_original_pi05():
|
def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None):
|
||||||
policy = PI0Pytorch(PI05BaseOriginalConfig()).to(DEVICE)
|
config = PI05BaseOriginalConfig()
|
||||||
|
policy = PI0Pytorch(config)
|
||||||
|
|
||||||
# NOTE: `lerobot/pi05_base` 的 LeRobot loader 和 PI0 一样会在 strict load 前做 key
|
if from_pretrained:
|
||||||
# 兼容转换,因此预期没有 missing_keys 或 unexpected_keys。vendored reference 则是裸
|
try:
|
||||||
# `nn.Module`,需要在测试侧补齐 checkpoint 与模块命名之间的最小差异。
|
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi05_base)...")
|
||||||
# NOTE: `lm_head.weight` 是 PaliGemma tied embedding 的保存名;LeRobot 的
|
|
||||||
# from_pretrained 会把它映射到内部 `embed_tokens.weight`,而 reference 模型没有这层
|
# Download the model from HuggingFace Hub
|
||||||
# loader,所以这里手动复用同一份 tensor,避免把权重别名差异误判成模型差异。
|
import safetensors.torch
|
||||||
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi05_base"))
|
from huggingface_hub import snapshot_download
|
||||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
|
||||||
assert missing_keys == []
|
# Download the entire repository
|
||||||
assert unexpected_keys == []
|
if model_path and os.path.exists(model_path):
|
||||||
|
cache_dir = model_path
|
||||||
|
print(f"Using cached model from: {cache_dir}")
|
||||||
|
else:
|
||||||
|
cache_dir = snapshot_download(repo_id="lerobot/pi05_base", repo_type="model")
|
||||||
|
print(f"Downloaded model to: {cache_dir}")
|
||||||
|
|
||||||
|
# Try to load safetensors format first
|
||||||
|
model_file = os.path.join(cache_dir, "model.safetensors")
|
||||||
|
if os.path.exists(model_file):
|
||||||
|
state_dict = safetensors.torch.load_file(model_file)
|
||||||
|
print(f"Loaded {len(state_dict)} parameters from safetensors")
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
|
||||||
|
|
||||||
|
# Load the state dict into the model
|
||||||
|
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
if missing_keys:
|
||||||
|
print(f"Missing keys: {len(missing_keys)}")
|
||||||
|
if len(missing_keys) <= 5:
|
||||||
|
for key in missing_keys:
|
||||||
|
print(f" - {key}")
|
||||||
|
else:
|
||||||
|
for key in missing_keys[:5]:
|
||||||
|
print(f" - {key}")
|
||||||
|
print(f" ... and {len(missing_keys) - 5} more")
|
||||||
|
|
||||||
|
if unexpected_keys:
|
||||||
|
print(f"Unexpected keys: {len(unexpected_keys)}")
|
||||||
|
if len(unexpected_keys) <= 5:
|
||||||
|
for key in unexpected_keys:
|
||||||
|
print(f" - {key}")
|
||||||
|
else:
|
||||||
|
for key in unexpected_keys[:5]:
|
||||||
|
print(f" - {key}")
|
||||||
|
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||||
|
|
||||||
|
if not missing_keys and not unexpected_keys:
|
||||||
|
print("All pretrained weights loaded successfully!")
|
||||||
|
else:
|
||||||
|
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load pretrained weights: {e}")
|
||||||
|
print(" Using randomly initialized weights...")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
policy.to(DEVICE)
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_data():
|
def create_dummy_data():
|
||||||
batch_size = 2
|
batch_size = 2 # Reduce batch size for testing
|
||||||
|
device = DEVICE
|
||||||
|
|
||||||
|
# Use the exact same prompt for both implementations
|
||||||
prompt = "Pick up the red block and place it in the bin"
|
prompt = "Pick up the red block and place it in the bin"
|
||||||
return {
|
|
||||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
batch = {
|
||||||
ACTION: torch.randn(
|
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
"action": torch.randn(
|
||||||
|
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
|
||||||
),
|
),
|
||||||
|
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
|
||||||
"observation.images.base_0_rgb": torch.rand(
|
"observation.images.base_0_rgb": torch.rand(
|
||||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||||
),
|
),
|
||||||
"observation.images.left_wrist_0_rgb": torch.rand(
|
"observation.images.left_wrist_0_rgb": torch.rand(
|
||||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||||
),
|
),
|
||||||
"observation.images.right_wrist_0_rgb": torch.rand(
|
"observation.images.right_wrist_0_rgb": torch.rand(
|
||||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||||
),
|
),
|
||||||
|
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||||
"task": [prompt for _ in range(batch_size)],
|
"task": [prompt for _ in range(batch_size)],
|
||||||
}
|
}
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
def prepare_parity_inputs(lerobot_pi05, lerobot_preprocessor):
|
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
|
||||||
torch.manual_seed(0)
|
"""Extract the exact same processed inputs that LeRobot uses internally."""
|
||||||
raw_batch = create_dummy_data()
|
# Get the tokenized language from LeRobot's internal method
|
||||||
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
|
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
|
||||||
openpi_observation = make_openpi_observation_from_raw(
|
|
||||||
raw_batch,
|
# Get the preprocessed images from LeRobot's internal method
|
||||||
action_dim=DUMMY_ACTION_DIM,
|
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
|
||||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
|
||||||
dataset_stats=DUMMY_DATASET_STATS,
|
# Create dummy token_ar_mask and token_loss_mask for original implementation
|
||||||
pi05=True,
|
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||||
)
|
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||||
openpi_actions = openpi_model_actions_from_raw(
|
|
||||||
raw_batch,
|
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
|
||||||
action_dim=DUMMY_ACTION_DIM,
|
|
||||||
dataset_stats=DUMMY_DATASET_STATS,
|
|
||||||
pi05=True,
|
|
||||||
)
|
|
||||||
assert_processor_inputs_match_lerobot(
|
|
||||||
lerobot_pi05,
|
|
||||||
lerobot_batch,
|
|
||||||
openpi_observation,
|
|
||||||
compare_state=False,
|
|
||||||
)
|
|
||||||
batch_size = raw_batch[OBS_STATE].shape[0]
|
|
||||||
noise = torch.randn(
|
|
||||||
batch_size,
|
|
||||||
DUMMY_ACTION_HORIZON,
|
|
||||||
DUMMY_ACTION_DIM,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=DEVICE,
|
|
||||||
)
|
|
||||||
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
|
|
||||||
return lerobot_batch, openpi_observation, openpi_actions, noise, time
|
|
||||||
|
|
||||||
|
|
||||||
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
class PI05Observation:
|
||||||
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(
|
"""Observation class that matches the original OpenPI format."""
|
||||||
compile_model=compile_model,
|
|
||||||
gradient_checkpointing=gradient_checkpointing,
|
def __init__(
|
||||||
)
|
self,
|
||||||
original_pi05 = instantiate_original_pi05()
|
state,
|
||||||
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
|
images,
|
||||||
lerobot_pi05,
|
image_masks,
|
||||||
lerobot_preprocessor,
|
tokenized_prompt,
|
||||||
|
tokenized_prompt_mask,
|
||||||
|
token_ar_mask,
|
||||||
|
token_loss_mask,
|
||||||
|
):
|
||||||
|
self.state = state
|
||||||
|
self.images = images
|
||||||
|
self.image_masks = image_masks
|
||||||
|
self.tokenized_prompt = tokenized_prompt
|
||||||
|
self.tokenized_prompt_mask = tokenized_prompt_mask
|
||||||
|
self.token_ar_mask = token_ar_mask
|
||||||
|
self.token_loss_mask = token_loss_mask
|
||||||
|
|
||||||
|
|
||||||
|
def create_original_observation_with_openpi_preprocessing(batch):
|
||||||
|
"""Create observation object for OpenPI using OpenPI's own preprocessing with pi05 state tokenizer."""
|
||||||
|
batch_size = batch["observation.state"].shape[0]
|
||||||
|
device = batch["observation.state"].device
|
||||||
|
|
||||||
|
# Create tokenizer for OpenPI (same as LeRobot uses)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||||
|
|
||||||
|
# Get task description (pi05 processor handles all text formatting)
|
||||||
|
tasks = batch.get("task", ["Pick up the object"] * batch_size)
|
||||||
|
if isinstance(tasks, str):
|
||||||
|
tasks = [tasks] * batch_size
|
||||||
|
elif len(tasks) == 1:
|
||||||
|
tasks = tasks * batch_size
|
||||||
|
|
||||||
|
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
|
||||||
|
state = batch["observation.state"]
|
||||||
|
state = deepcopy(state)
|
||||||
|
|
||||||
|
# Prepare state (pad to max_state_dim)
|
||||||
|
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||||
|
|
||||||
|
state = pad_vector(state, DUMMY_STATE_DIM)
|
||||||
|
|
||||||
|
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
|
||||||
|
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||||
|
state_np = state.cpu().numpy()
|
||||||
|
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||||
|
|
||||||
|
# Create pi05-formatted prompts that include state information
|
||||||
|
full_prompts = []
|
||||||
|
for i, task in enumerate(tasks):
|
||||||
|
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||||
|
state_str = " ".join(map(str, discretized_states[i]))
|
||||||
|
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||||
|
full_prompts.append(full_prompt)
|
||||||
|
|
||||||
|
# Tokenize with max_length padding to match OpenPI's expected format
|
||||||
|
tokenized = tokenizer(
|
||||||
|
full_prompts,
|
||||||
|
padding="max_length",
|
||||||
|
padding_side="right",
|
||||||
|
truncation=True,
|
||||||
|
max_length=DUMMY_MAX_TOKEN_LEN,
|
||||||
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
if gradient_checkpointing:
|
lang_tokens = tokenized["input_ids"].to(device)
|
||||||
lerobot_pi05.train()
|
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||||
else:
|
|
||||||
lerobot_pi05.eval()
|
|
||||||
original_pi05.eval()
|
|
||||||
|
|
||||||
with fixed_flow_sampling(lerobot_pi05.model, noise=noise, time=time):
|
# Create dummy token_ar_mask and token_loss_mask for OpenPI
|
||||||
lerobot_loss, _ = lerobot_pi05(lerobot_batch, reduction="none")
|
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||||
with deterministic_openpi_forward_preprocess(original_pi05):
|
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||||
openpi_losses = original_pi05(openpi_observation, openpi_actions, noise=noise, time=time)
|
|
||||||
openpi_loss = openpi_losses.mean(dim=(1, 2))
|
|
||||||
|
|
||||||
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
|
||||||
|
image_dict = {
|
||||||
|
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
|
||||||
|
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
|
||||||
|
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create image masks (all ones for real images)
|
||||||
|
image_masks_dict = {}
|
||||||
|
for key in image_dict:
|
||||||
|
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||||
|
|
||||||
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
|
# Create raw observation object (before preprocessing)
|
||||||
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(compile_model=compile_model)
|
raw_observation = PI05Observation(
|
||||||
original_pi05 = instantiate_original_pi05()
|
state=batch["observation.state"],
|
||||||
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
|
images=image_dict,
|
||||||
lerobot_pi05,
|
image_masks=image_masks_dict,
|
||||||
lerobot_preprocessor,
|
tokenized_prompt=lang_tokens,
|
||||||
|
tokenized_prompt_mask=lang_masks,
|
||||||
|
token_ar_mask=token_ar_mask,
|
||||||
|
token_loss_mask=token_loss_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
lerobot_pi05.eval()
|
# Now use OpenPI's preprocessing
|
||||||
original_pi05.eval()
|
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
|
||||||
|
|
||||||
|
return processed_obs
|
||||||
|
|
||||||
|
|
||||||
|
def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||||
|
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
|
||||||
|
_batch_size = batch["observation.state"].shape[0]
|
||||||
|
_device = batch["observation.state"].device
|
||||||
|
|
||||||
|
# Extract the exact same processed inputs that LeRobot uses
|
||||||
|
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
|
||||||
|
extract_lerobot_processed_inputs(lerobot_pi0, batch)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert images list to dict with original OpenPI keys
|
||||||
|
image_dict = {
|
||||||
|
"base_0_rgb": images[0],
|
||||||
|
"left_wrist_0_rgb": images[1],
|
||||||
|
"right_wrist_0_rgb": images[2],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Convert image masks list to dict with original OpenPI keys
|
||||||
|
image_masks_dict = {
|
||||||
|
"base_0_rgb": img_masks[0],
|
||||||
|
"left_wrist_0_rgb": img_masks[1],
|
||||||
|
"right_wrist_0_rgb": img_masks[2],
|
||||||
|
}
|
||||||
|
|
||||||
|
return PI05Observation(
|
||||||
|
state=batch["observation.state"],
|
||||||
|
images=image_dict,
|
||||||
|
image_masks=image_masks_dict,
|
||||||
|
tokenized_prompt=lang_tokens,
|
||||||
|
tokenized_prompt_mask=lang_masks,
|
||||||
|
token_ar_mask=token_ar_mask,
|
||||||
|
token_loss_mask=token_loss_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pi05_original_vs_lerobot():
|
||||||
|
"""Test PI05 original implementation vs LeRobot implementation."""
|
||||||
|
print("Initializing models...")
|
||||||
|
lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05(
|
||||||
|
from_pretrained=True
|
||||||
|
) # Load pretrained LeRobot model
|
||||||
|
original_pi0 = instantiate_original_pi05(
|
||||||
|
from_pretrained=True
|
||||||
|
) # Load pretrained OpenPI model from HuggingFace Hub
|
||||||
|
|
||||||
|
print("Creating dummy data...")
|
||||||
|
batch = create_dummy_data()
|
||||||
|
batch_lerobot = deepcopy(batch)
|
||||||
|
|
||||||
|
# Test each model with its own preprocessing (more realistic end-to-end test)
|
||||||
|
print("\nTest each model with its own preprocessing")
|
||||||
|
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||||
|
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||||
|
|
||||||
|
print(f"Task prompt: '{batch['task'][0]}'")
|
||||||
|
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
|
||||||
|
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
|
||||||
|
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
|
||||||
|
|
||||||
|
print("Testing OpenPI with own preprocessing...")
|
||||||
|
original_pi0.eval()
|
||||||
|
torch.manual_seed(42) # Set seed for reproducibility
|
||||||
|
batch_size = batch["observation.state"].shape[0]
|
||||||
|
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
|
||||||
|
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
lerobot_actions = lerobot_pi05.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
|
openpi_actions = original_pi0.sample_actions(
|
||||||
openpi_actions = original_pi05.sample_actions(
|
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
|
||||||
device=DEVICE,
|
|
||||||
observation=openpi_observation,
|
|
||||||
noise=noise,
|
|
||||||
num_steps=10,
|
|
||||||
)
|
)
|
||||||
|
openpi_actions_unit = openpi_actions[:, 0, :]
|
||||||
|
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
|
||||||
|
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
|
||||||
|
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
|
||||||
|
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
|
||||||
|
|
||||||
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
print("Testing LeRobot with own preprocessing...")
|
||||||
|
lerobot_pi05.eval()
|
||||||
|
torch.manual_seed(42) # Set the same seed
|
||||||
|
|
||||||
|
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||||
|
with torch.no_grad():
|
||||||
|
lerobot_actions_own = lerobot_pi05.predict_action_chunk(
|
||||||
|
batch_lerobot_processed
|
||||||
|
) # batch_size, n_action_steps, action_dim
|
||||||
|
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
|
||||||
|
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
|
||||||
|
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
|
||||||
|
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
|
||||||
|
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
|
||||||
|
|
||||||
def test_pi05_forward_matches_openpi():
|
print("\nComparing end-to-end implementations:")
|
||||||
assert_forward_matches()
|
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
|
||||||
|
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||||
|
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||||
|
|
||||||
|
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
|
||||||
def test_pi05_sample_actions_match_openpi():
|
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
|
||||||
assert_sample_actions_match_openpi()
|
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
|
||||||
|
|
||||||
|
|
||||||
def test_pi05_gradient_checkpointing_forward_matches_openpi():
|
|
||||||
assert_forward_matches(gradient_checkpointing=True)
|
|
||||||
|
|
||||||
|
|
||||||
def test_pi05_compile_forward_matches_openpi():
|
|
||||||
assert_forward_matches(compile_model=True)
|
|
||||||
|
|
||||||
|
|
||||||
def test_pi05_compile_sample_actions_match_openpi():
|
|
||||||
assert_sample_actions_match_openpi(compile_model=True)
|
|
||||||
|
|
||||||
|
|
||||||
def test_pi05_compile_gradient_checkpointing_forward_matches_openpi():
|
|
||||||
assert_forward_matches(compile_model=True, gradient_checkpointing=True)
|
|
||||||
|
|||||||
@@ -1,99 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
pytest.importorskip("transformers")
|
|
||||||
|
|
||||||
from lerobot.policies.pi0 import PI0Config # noqa: E402
|
|
||||||
from lerobot.policies.pi0.modeling_pi0 import PI0Pytorch # noqa: E402
|
|
||||||
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
|
|
||||||
assert_cache_stability,
|
|
||||||
assert_compiled_output_matches_eager,
|
|
||||||
assert_explain_has_no_graph_breaks,
|
|
||||||
benchmark_runtime,
|
|
||||||
make_compile_config,
|
|
||||||
reset_compile_state,
|
|
||||||
)
|
|
||||||
from tests.utils import require_cuda # noqa: E402
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(
|
|
||||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
|
||||||
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_model(*, compile_model):
|
|
||||||
return PI0Pytorch(make_compile_config(PI0Config, compile_model=compile_model)).cuda().eval()
|
|
||||||
|
|
||||||
|
|
||||||
def _make_dummy_inputs(config):
|
|
||||||
device = torch.device("cuda")
|
|
||||||
common = {
|
|
||||||
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
|
|
||||||
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
|
|
||||||
"lang_tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
|
|
||||||
"lang_masks": torch.ones(1, 5, dtype=torch.bool, device=device),
|
|
||||||
"state": torch.randn(1, config.max_state_dim, device=device),
|
|
||||||
}
|
|
||||||
forward_kwargs = {
|
|
||||||
**common,
|
|
||||||
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
|
||||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
|
||||||
"time": torch.rand(1, device=device),
|
|
||||||
}
|
|
||||||
sample_kwargs = {
|
|
||||||
**common,
|
|
||||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
|
||||||
"num_steps": config.num_inference_steps,
|
|
||||||
}
|
|
||||||
return forward_kwargs, sample_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
@require_cuda
|
|
||||||
def test_pi0_torch_compile_forward_and_sample_actions():
|
|
||||||
if not hasattr(torch, "compile"):
|
|
||||||
pytest.skip("torch.compile is not available")
|
|
||||||
if not torch._dynamo.is_dynamo_supported():
|
|
||||||
pytest.skip("torch._dynamo is not supported on this platform")
|
|
||||||
|
|
||||||
torch.manual_seed(0)
|
|
||||||
eager_model = _make_model(compile_model=False)
|
|
||||||
torch.manual_seed(0)
|
|
||||||
compiled_model = _make_model(compile_model=True)
|
|
||||||
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
|
|
||||||
|
|
||||||
try:
|
|
||||||
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
|
|
||||||
|
|
||||||
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi0.forward")
|
|
||||||
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi0.sample_actions")
|
|
||||||
|
|
||||||
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi0.forward")
|
|
||||||
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions")
|
|
||||||
|
|
||||||
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi0.forward")
|
|
||||||
benchmark_runtime(
|
|
||||||
eager_model.sample_actions, compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions"
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
reset_compile_state()
|
|
||||||
del eager_model
|
|
||||||
del compiled_model
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
@@ -14,56 +14,51 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Compare LeRobot PI0 against the vendored OpenPI PyTorch reference."""
|
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation"""
|
||||||
|
|
||||||
import gc
|
|
||||||
import os
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
# Skip if openpi or transformers is not available
|
||||||
|
pytest.importorskip("openpi")
|
||||||
pytest.importorskip("transformers")
|
pytest.importorskip("transformers")
|
||||||
|
|
||||||
from lerobot.configs import PreTrainedConfig # noqa: E402
|
# Skip this entire module in CI
|
||||||
from lerobot.policies.pi0 import PI0Policy # noqa: E402
|
|
||||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
|
||||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
|
||||||
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
|
||||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
|
||||||
assert_processor_inputs_match_lerobot,
|
|
||||||
clone_batch,
|
|
||||||
deterministic_openpi_forward_preprocess,
|
|
||||||
fix_reference_state_dict,
|
|
||||||
fixed_flow_sampling,
|
|
||||||
load_openpi_reference_state_dict,
|
|
||||||
make_openpi_observation_from_raw,
|
|
||||||
openpi_model_actions_from_raw,
|
|
||||||
)
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(
|
pytestmark = pytest.mark.skipif(
|
||||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||||
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
|
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||||
|
|
||||||
|
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||||
|
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||||
|
from transformers import AutoTokenizer # noqa: E402
|
||||||
|
|
||||||
|
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
|
||||||
|
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||||
|
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
|
||||||
|
from lerobot.types import PolicyAction # noqa: E402
|
||||||
|
|
||||||
|
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||||
DUMMY_ACTION_DIM = 32
|
DUMMY_ACTION_DIM = 32
|
||||||
DUMMY_STATE_DIM = 32
|
DUMMY_STATE_DIM = 32
|
||||||
DUMMY_ACTION_HORIZON = 50
|
DUMMY_ACTION_HORIZON = 50
|
||||||
DUMMY_MAX_TOKEN_LEN = 48
|
DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05)
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||||
COMPILE_MODE = "default"
|
|
||||||
FORWARD_RTOL = 1e-4
|
|
||||||
FORWARD_ATOL = 1e-4
|
|
||||||
SAMPLE_RTOL = 1e-2
|
|
||||||
SAMPLE_ATOL = 5e-3
|
|
||||||
|
|
||||||
DUMMY_DATASET_STATS = {
|
DUMMY_DATASET_STATS = {
|
||||||
OBS_STATE: {
|
"observation.state": {
|
||||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||||
"std": torch.ones(DUMMY_STATE_DIM),
|
"std": torch.ones(DUMMY_STATE_DIM),
|
||||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||||
},
|
},
|
||||||
ACTION: {
|
"action": {
|
||||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||||
@@ -92,15 +87,6 @@ DUMMY_DATASET_STATS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def cleanup_cuda_after_test():
|
|
||||||
yield
|
|
||||||
gc.collect()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.ipc_collect()
|
|
||||||
|
|
||||||
|
|
||||||
class PI0BaseOriginalConfig:
|
class PI0BaseOriginalConfig:
|
||||||
action_dim: int = DUMMY_ACTION_DIM
|
action_dim: int = DUMMY_ACTION_DIM
|
||||||
action_horizon: int = DUMMY_ACTION_HORIZON
|
action_horizon: int = DUMMY_ACTION_HORIZON
|
||||||
@@ -109,156 +95,333 @@ class PI0BaseOriginalConfig:
|
|||||||
precision: str = "float32"
|
precision: str = "float32"
|
||||||
pi05: bool = False
|
pi05: bool = False
|
||||||
dtype: str = "float32"
|
dtype: str = "float32"
|
||||||
pytorch_compile_mode: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def instantiate_lerobot_pi0(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
def instantiate_lerobot_pi0(
|
||||||
config = PreTrainedConfig.from_pretrained("lerobot/pi0_base")
|
from_pretrained: bool = False,
|
||||||
config.device = str(DEVICE)
|
) -> tuple[
|
||||||
config.dtype = "float32"
|
PI0Policy,
|
||||||
config.compile_model = compile_model
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
config.compile_mode = COMPILE_MODE
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
config.gradient_checkpointing = gradient_checkpointing
|
]:
|
||||||
|
if from_pretrained:
|
||||||
|
# Load the policy first
|
||||||
|
policy = PI0Policy.from_pretrained(pretrained_name_or_path="lerobot/pi0_base", strict=True)
|
||||||
|
else:
|
||||||
|
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||||
|
policy = PI0Policy(config)
|
||||||
|
|
||||||
policy = PI0Policy.from_pretrained("lerobot/pi0_base", config=config, strict=True)
|
|
||||||
policy.to(DEVICE)
|
policy.to(DEVICE)
|
||||||
policy.config.device = str(DEVICE)
|
policy.config.device = DEVICE
|
||||||
preprocessor, _ = make_pi0_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
|
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||||
return policy, preprocessor
|
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||||
|
)
|
||||||
|
return (policy, preprocessor, postprocessor)
|
||||||
|
|
||||||
|
|
||||||
def instantiate_original_pi0():
|
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None):
|
||||||
policy = PI0Pytorch(PI0BaseOriginalConfig()).to(DEVICE)
|
config = PI0BaseOriginalConfig()
|
||||||
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi0_base"))
|
policy = PI0Pytorch(config)
|
||||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
|
||||||
assert missing_keys == []
|
if from_pretrained:
|
||||||
assert unexpected_keys == []
|
try:
|
||||||
|
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi0_base)...")
|
||||||
|
|
||||||
|
# Download the model from HuggingFace Hub
|
||||||
|
import safetensors.torch
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
# Download the entire repository
|
||||||
|
if model_path and os.path.exists(model_path):
|
||||||
|
cache_dir = model_path
|
||||||
|
print(f"Using cached model from: {cache_dir}")
|
||||||
|
else:
|
||||||
|
cache_dir = snapshot_download(repo_id="lerobot/pi0_base", repo_type="model")
|
||||||
|
print(f"Downloaded model to: {cache_dir}")
|
||||||
|
|
||||||
|
# Try to load safetensors format first
|
||||||
|
model_file = os.path.join(cache_dir, "model.safetensors")
|
||||||
|
if os.path.exists(model_file):
|
||||||
|
state_dict = safetensors.torch.load_file(model_file)
|
||||||
|
print(f"Loaded {len(state_dict)} parameters from safetensors")
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
|
||||||
|
|
||||||
|
# Load the state dict into the model
|
||||||
|
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
if missing_keys:
|
||||||
|
print(f"Missing keys: {len(missing_keys)}")
|
||||||
|
if len(missing_keys) <= 5:
|
||||||
|
for key in missing_keys:
|
||||||
|
print(f" - {key}")
|
||||||
|
else:
|
||||||
|
for key in missing_keys[:5]:
|
||||||
|
print(f" - {key}")
|
||||||
|
print(f" ... and {len(missing_keys) - 5} more")
|
||||||
|
|
||||||
|
if unexpected_keys:
|
||||||
|
print(f"Unexpected keys: {len(unexpected_keys)}")
|
||||||
|
if len(unexpected_keys) <= 5:
|
||||||
|
for key in unexpected_keys:
|
||||||
|
print(f" - {key}")
|
||||||
|
else:
|
||||||
|
for key in unexpected_keys[:5]:
|
||||||
|
print(f" - {key}")
|
||||||
|
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||||
|
|
||||||
|
if not missing_keys and not unexpected_keys:
|
||||||
|
print("All pretrained weights loaded successfully!")
|
||||||
|
else:
|
||||||
|
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load pretrained weights: {e}")
|
||||||
|
print(" Using randomly initialized weights...")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
policy.to(DEVICE)
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_data():
|
def create_dummy_data():
|
||||||
batch_size = 2
|
batch_size = 2 # Reduce batch size for testing
|
||||||
|
device = DEVICE
|
||||||
|
|
||||||
|
# Use the exact same prompt for both implementations
|
||||||
prompt = "Pick up the red block and place it in the bin"
|
prompt = "Pick up the red block and place it in the bin"
|
||||||
return {
|
|
||||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
batch = {
|
||||||
ACTION: torch.randn(
|
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
"action": torch.randn(
|
||||||
|
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
|
||||||
),
|
),
|
||||||
|
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
|
||||||
"observation.images.base_0_rgb": torch.rand(
|
"observation.images.base_0_rgb": torch.rand(
|
||||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||||
),
|
),
|
||||||
"observation.images.left_wrist_0_rgb": torch.rand(
|
"observation.images.left_wrist_0_rgb": torch.rand(
|
||||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||||
),
|
),
|
||||||
"observation.images.right_wrist_0_rgb": torch.rand(
|
"observation.images.right_wrist_0_rgb": torch.rand(
|
||||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||||
),
|
),
|
||||||
|
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||||
"task": [prompt for _ in range(batch_size)],
|
"task": [prompt for _ in range(batch_size)],
|
||||||
}
|
}
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
def prepare_parity_inputs(lerobot_pi0, lerobot_preprocessor):
|
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
|
||||||
torch.manual_seed(0)
|
"""Extract the exact same processed inputs that LeRobot uses internally."""
|
||||||
raw_batch = create_dummy_data()
|
# Get the tokenized language from LeRobot's internal method
|
||||||
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
|
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
|
||||||
openpi_observation = make_openpi_observation_from_raw(
|
|
||||||
raw_batch,
|
# Get the preprocessed images from LeRobot's internal method
|
||||||
action_dim=DUMMY_ACTION_DIM,
|
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
|
||||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
|
||||||
dataset_stats=DUMMY_DATASET_STATS,
|
# Create dummy token_ar_mask and token_loss_mask for original implementation
|
||||||
pi05=False,
|
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||||
)
|
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||||
openpi_actions = openpi_model_actions_from_raw(
|
|
||||||
raw_batch,
|
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
|
||||||
action_dim=DUMMY_ACTION_DIM,
|
|
||||||
dataset_stats=DUMMY_DATASET_STATS,
|
|
||||||
pi05=False,
|
|
||||||
)
|
|
||||||
assert_processor_inputs_match_lerobot(
|
|
||||||
lerobot_pi0,
|
|
||||||
lerobot_batch,
|
|
||||||
openpi_observation,
|
|
||||||
compare_state=True,
|
|
||||||
)
|
|
||||||
batch_size = raw_batch[OBS_STATE].shape[0]
|
|
||||||
noise = torch.randn(
|
|
||||||
batch_size,
|
|
||||||
DUMMY_ACTION_HORIZON,
|
|
||||||
DUMMY_ACTION_DIM,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=DEVICE,
|
|
||||||
)
|
|
||||||
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
|
|
||||||
return lerobot_batch, openpi_observation, openpi_actions, noise, time
|
|
||||||
|
|
||||||
|
|
||||||
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
class PI0Observation:
|
||||||
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(
|
"""Observation class that matches the original OpenPI format."""
|
||||||
compile_model=compile_model,
|
|
||||||
gradient_checkpointing=gradient_checkpointing,
|
|
||||||
)
|
|
||||||
original_pi0 = instantiate_original_pi0()
|
|
||||||
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
|
|
||||||
lerobot_pi0,
|
|
||||||
lerobot_preprocessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if gradient_checkpointing:
|
def __init__(
|
||||||
lerobot_pi0.train()
|
self,
|
||||||
|
state,
|
||||||
|
images,
|
||||||
|
image_masks,
|
||||||
|
tokenized_prompt,
|
||||||
|
tokenized_prompt_mask,
|
||||||
|
token_ar_mask,
|
||||||
|
token_loss_mask,
|
||||||
|
):
|
||||||
|
self.state = state
|
||||||
|
self.images = images
|
||||||
|
self.image_masks = image_masks
|
||||||
|
self.tokenized_prompt = tokenized_prompt
|
||||||
|
self.tokenized_prompt_mask = tokenized_prompt_mask
|
||||||
|
self.token_ar_mask = token_ar_mask
|
||||||
|
self.token_loss_mask = token_loss_mask
|
||||||
|
|
||||||
|
|
||||||
|
def create_original_observation_with_openpi_preprocessing(batch):
|
||||||
|
"""Create observation object for OpenPI using OpenPI's own preprocessing."""
|
||||||
|
batch_size = batch["observation.state"].shape[0]
|
||||||
|
device = batch["observation.state"].device
|
||||||
|
|
||||||
|
# Create tokenizer for OpenPI (same as LeRobot uses)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||||
|
|
||||||
|
# Get task description
|
||||||
|
if "task" in batch:
|
||||||
|
tasks = batch["task"]
|
||||||
|
if isinstance(tasks, str):
|
||||||
|
# Single string: add newline if not present, then convert to list
|
||||||
|
if not tasks.endswith("\n"):
|
||||||
|
tasks = f"{tasks}\n"
|
||||||
|
tasks = [tasks]
|
||||||
|
elif isinstance(tasks, list) and all(isinstance(t, str) for t in tasks):
|
||||||
|
# List of strings: add newline to each if not present
|
||||||
|
tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks]
|
||||||
|
if len(tasks) == 1:
|
||||||
|
# Expand to batch size
|
||||||
|
tasks = tasks * batch_size
|
||||||
|
if len(tasks) != batch_size:
|
||||||
|
raise ValueError(f"Expected batch size {batch_size}, got {len(tasks)}")
|
||||||
|
# If task is neither string nor list of strings, leave unchanged
|
||||||
else:
|
else:
|
||||||
lerobot_pi0.eval()
|
# Default task if not provided
|
||||||
original_pi0.eval()
|
tasks = ["Pick up the object\n"] * batch_size
|
||||||
|
|
||||||
with fixed_flow_sampling(lerobot_pi0.model, noise=noise, time=time):
|
# Tokenize with max_length padding to match OpenPI's expected format
|
||||||
lerobot_loss, _ = lerobot_pi0(lerobot_batch, reduction="none")
|
tokenized = tokenizer(
|
||||||
with deterministic_openpi_forward_preprocess(original_pi0):
|
tasks,
|
||||||
openpi_losses = original_pi0(openpi_observation, openpi_actions, noise=noise, time=time)
|
padding="max_length",
|
||||||
openpi_loss = openpi_losses.mean(dim=(1, 2))
|
padding_side="right",
|
||||||
|
truncation=True,
|
||||||
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
max_length=DUMMY_MAX_TOKEN_LEN,
|
||||||
|
return_tensors="pt",
|
||||||
|
|
||||||
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
|
|
||||||
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(compile_model=compile_model)
|
|
||||||
original_pi0 = instantiate_original_pi0()
|
|
||||||
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
|
|
||||||
lerobot_pi0,
|
|
||||||
lerobot_preprocessor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
lerobot_pi0.eval()
|
lang_tokens = tokenized["input_ids"].to(device)
|
||||||
|
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||||
|
|
||||||
|
# Create dummy token_ar_mask and token_loss_mask for OpenPI
|
||||||
|
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||||
|
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||||
|
|
||||||
|
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
|
||||||
|
image_dict = {
|
||||||
|
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
|
||||||
|
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
|
||||||
|
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create image masks (all ones for real images)
|
||||||
|
image_masks_dict = {}
|
||||||
|
for key in image_dict:
|
||||||
|
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||||
|
|
||||||
|
# Create raw observation object (before preprocessing)
|
||||||
|
raw_observation = PI0Observation(
|
||||||
|
state=batch["observation.state"],
|
||||||
|
images=image_dict,
|
||||||
|
image_masks=image_masks_dict,
|
||||||
|
tokenized_prompt=lang_tokens,
|
||||||
|
tokenized_prompt_mask=lang_masks,
|
||||||
|
token_ar_mask=token_ar_mask,
|
||||||
|
token_loss_mask=token_loss_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now use OpenPI's preprocessing
|
||||||
|
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
|
||||||
|
|
||||||
|
return processed_obs
|
||||||
|
|
||||||
|
|
||||||
|
def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||||
|
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
|
||||||
|
_batch_size = batch["observation.state"].shape[0]
|
||||||
|
_device = batch["observation.state"].device
|
||||||
|
|
||||||
|
# Extract the exact same processed inputs that LeRobot uses
|
||||||
|
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
|
||||||
|
extract_lerobot_processed_inputs(lerobot_pi0, batch)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert images list to dict with original OpenPI keys
|
||||||
|
image_dict = {
|
||||||
|
"base_0_rgb": images[0],
|
||||||
|
"left_wrist_0_rgb": images[1],
|
||||||
|
"right_wrist_0_rgb": images[2],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Convert image masks list to dict with original OpenPI keys
|
||||||
|
image_masks_dict = {
|
||||||
|
"base_0_rgb": img_masks[0],
|
||||||
|
"left_wrist_0_rgb": img_masks[1],
|
||||||
|
"right_wrist_0_rgb": img_masks[2],
|
||||||
|
}
|
||||||
|
|
||||||
|
return PI0Observation(
|
||||||
|
state=batch["observation.state"],
|
||||||
|
images=image_dict,
|
||||||
|
image_masks=image_masks_dict,
|
||||||
|
tokenized_prompt=lang_tokens,
|
||||||
|
tokenized_prompt_mask=lang_masks,
|
||||||
|
token_ar_mask=token_ar_mask,
|
||||||
|
token_loss_mask=token_loss_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pi0_original_vs_lerobot():
|
||||||
|
"""Test PI0 original implementation vs LeRobot implementation."""
|
||||||
|
print("Initializing models...")
|
||||||
|
lerobot_pi0, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi0(
|
||||||
|
from_pretrained=True
|
||||||
|
) # Load pretrained LeRobot model
|
||||||
|
original_pi0 = instantiate_original_pi0(
|
||||||
|
from_pretrained=True
|
||||||
|
) # Load pretrained OpenPI model from HuggingFace Hub
|
||||||
|
|
||||||
|
print("Creating dummy data...")
|
||||||
|
batch = create_dummy_data()
|
||||||
|
batch_lerobot = deepcopy(batch)
|
||||||
|
|
||||||
|
# Test each model with its own preprocessing (more realistic end-to-end test)
|
||||||
|
print("\nTest each model with its own preprocessing")
|
||||||
|
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||||
|
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||||
|
|
||||||
|
print(f"Task prompt: '{batch['task'][0]}'")
|
||||||
|
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
|
||||||
|
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
|
||||||
|
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
|
||||||
|
|
||||||
|
print("Testing OpenPI with own preprocessing...")
|
||||||
original_pi0.eval()
|
original_pi0.eval()
|
||||||
|
torch.manual_seed(42) # Set seed for reproducibility
|
||||||
|
batch_size = batch["observation.state"].shape[0]
|
||||||
|
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
|
||||||
|
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
lerobot_actions = lerobot_pi0.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
|
|
||||||
openpi_actions = original_pi0.sample_actions(
|
openpi_actions = original_pi0.sample_actions(
|
||||||
device=DEVICE,
|
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
|
||||||
observation=openpi_observation,
|
|
||||||
noise=noise,
|
|
||||||
num_steps=10,
|
|
||||||
)
|
)
|
||||||
|
openpi_actions_unit = openpi_actions[:, 0, :]
|
||||||
|
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
|
||||||
|
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
|
||||||
|
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
|
||||||
|
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
|
||||||
|
|
||||||
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
print("Testing LeRobot with own preprocessing...")
|
||||||
|
lerobot_pi0.eval()
|
||||||
|
torch.manual_seed(42) # Set the same seed
|
||||||
|
|
||||||
|
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||||
|
with torch.no_grad():
|
||||||
|
lerobot_actions_own = lerobot_pi0.predict_action_chunk(
|
||||||
|
batch_lerobot_processed
|
||||||
|
) # batch_size, n_action_steps, action_dim
|
||||||
|
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
|
||||||
|
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
|
||||||
|
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
|
||||||
|
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
|
||||||
|
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
|
||||||
|
|
||||||
def test_pi0_forward_matches_openpi():
|
print("\nComparing end-to-end implementations:")
|
||||||
assert_forward_matches()
|
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
|
||||||
|
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||||
|
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||||
|
|
||||||
|
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
|
||||||
def test_pi0_sample_actions_match_openpi():
|
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
|
||||||
assert_sample_actions_match_openpi()
|
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
|
||||||
|
|
||||||
|
|
||||||
def test_pi0_gradient_checkpointing_forward_matches_openpi():
|
|
||||||
assert_forward_matches(gradient_checkpointing=True)
|
|
||||||
|
|
||||||
|
|
||||||
def test_pi0_compile_forward_matches_openpi():
|
|
||||||
assert_forward_matches(compile_model=True)
|
|
||||||
|
|
||||||
|
|
||||||
def test_pi0_compile_sample_actions_match_openpi():
|
|
||||||
assert_sample_actions_match_openpi(compile_model=True)
|
|
||||||
|
|
||||||
|
|
||||||
def test_pi0_compile_gradient_checkpointing_forward_matches_openpi():
|
|
||||||
assert_forward_matches(compile_model=True, gradient_checkpointing=True)
|
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
"""Utilities shared by PI0/PI05 policy tests."""
|
|
||||||
@@ -1,291 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Iterator
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from functools import lru_cache
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import safetensors.torch
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F # noqa: N812
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
from lerobot.utils.constants import (
|
|
||||||
ACTION,
|
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
|
||||||
OBS_LANGUAGE_TOKENS,
|
|
||||||
OBS_STATE,
|
|
||||||
)
|
|
||||||
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as openpi_preprocessing
|
|
||||||
|
|
||||||
IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
|
|
||||||
TOKENIZER_NAME = "google/paligemma-3b-pt-224"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class OpenPIObservation:
|
|
||||||
state: torch.Tensor
|
|
||||||
images: dict[str, torch.Tensor]
|
|
||||||
image_masks: dict[str, torch.Tensor]
|
|
||||||
tokenized_prompt: torch.Tensor
|
|
||||||
tokenized_prompt_mask: torch.Tensor
|
|
||||||
token_ar_mask: torch.Tensor
|
|
||||||
token_loss_mask: torch.Tensor
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
|
||||||
def paligemma_tokenizer():
|
|
||||||
return AutoTokenizer.from_pretrained(TOKENIZER_NAME)
|
|
||||||
|
|
||||||
|
|
||||||
def clone_batch(batch: dict) -> dict:
|
|
||||||
return {
|
|
||||||
key: value.clone() if isinstance(value, torch.Tensor) else list(value) for key, value in batch.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def pad_last_dim(tensor: torch.Tensor, target_dim: int) -> torch.Tensor:
|
|
||||||
if tensor.shape[-1] > target_dim:
|
|
||||||
raise ValueError(f"Cannot pad last dimension {tensor.shape[-1]} down to {target_dim}")
|
|
||||||
return F.pad(tensor, (0, target_dim - tensor.shape[-1]))
|
|
||||||
|
|
||||||
|
|
||||||
def mean_std_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
||||||
mean = stats["mean"].to(device=tensor.device, dtype=tensor.dtype)
|
|
||||||
std = stats["std"].to(device=tensor.device, dtype=tensor.dtype)
|
|
||||||
return (tensor - mean) / (std + 1e-8)
|
|
||||||
|
|
||||||
|
|
||||||
def quantile_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
||||||
q01 = stats["q01"].to(device=tensor.device, dtype=tensor.dtype)
|
|
||||||
q99 = stats["q99"].to(device=tensor.device, dtype=tensor.dtype)
|
|
||||||
denom = torch.where(q99 == q01, torch.full_like(q99, 1e-8), q99 - q01)
|
|
||||||
return 2.0 * (tensor - q01) / denom - 1.0
|
|
||||||
|
|
||||||
|
|
||||||
def openpi_model_state_from_raw(
|
|
||||||
batch: dict[str, torch.Tensor],
|
|
||||||
*,
|
|
||||||
action_dim: int,
|
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
|
||||||
pi05: bool,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
state = batch[OBS_STATE].to(dtype=torch.float32)
|
|
||||||
if pi05:
|
|
||||||
state = quantile_normalize(state, dataset_stats[OBS_STATE])
|
|
||||||
else:
|
|
||||||
state = mean_std_normalize(state, dataset_stats[OBS_STATE])
|
|
||||||
return pad_last_dim(state, action_dim)
|
|
||||||
|
|
||||||
|
|
||||||
def openpi_model_actions_from_raw(
|
|
||||||
batch: dict[str, torch.Tensor],
|
|
||||||
*,
|
|
||||||
action_dim: int,
|
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
|
||||||
pi05: bool,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
actions = batch[ACTION].to(dtype=torch.float32)
|
|
||||||
if pi05:
|
|
||||||
actions = quantile_normalize(actions, dataset_stats[ACTION])
|
|
||||||
else:
|
|
||||||
actions = mean_std_normalize(actions, dataset_stats[ACTION])
|
|
||||||
return pad_last_dim(actions, action_dim)
|
|
||||||
|
|
||||||
|
|
||||||
def _tasks_from_raw(batch: dict, batch_size: int) -> list[str]:
|
|
||||||
tasks = batch.get("task")
|
|
||||||
if tasks is None:
|
|
||||||
raise ValueError("The parity batch must include a task prompt.")
|
|
||||||
if isinstance(tasks, str):
|
|
||||||
return [tasks] * batch_size
|
|
||||||
if len(tasks) == 1:
|
|
||||||
return [tasks[0]] * batch_size
|
|
||||||
if len(tasks) != batch_size:
|
|
||||||
raise ValueError(f"Expected {batch_size} task prompts, got {len(tasks)}")
|
|
||||||
return list(tasks)
|
|
||||||
|
|
||||||
|
|
||||||
def _format_pi0_prompts(tasks: list[str]) -> list[str]:
|
|
||||||
return [f"{task.strip().replace('_', ' ').replace(chr(10), ' ')}\n" for task in tasks]
|
|
||||||
|
|
||||||
|
|
||||||
def _format_pi05_prompts(tasks: list[str], normalized_state: torch.Tensor) -> list[str]:
|
|
||||||
state_np = normalized_state.detach().cpu().numpy()
|
|
||||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
|
||||||
prompts = []
|
|
||||||
for task, state in zip(tasks, discretized_states, strict=True):
|
|
||||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
|
||||||
state_str = " ".join(map(str, state))
|
|
||||||
prompts.append(f"Task: {cleaned_text}, State: {state_str};\nAction: ")
|
|
||||||
return prompts
|
|
||||||
|
|
||||||
|
|
||||||
def _tokenize_prompts(prompts: list[str], *, max_token_len: int, device: torch.device | str):
|
|
||||||
tokenized = paligemma_tokenizer()(
|
|
||||||
prompts,
|
|
||||||
padding="max_length",
|
|
||||||
padding_side="right",
|
|
||||||
truncation=True,
|
|
||||||
max_length=max_token_len,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
tokens = tokenized["input_ids"].to(device)
|
|
||||||
masks = tokenized["attention_mask"].to(device=device, dtype=torch.bool)
|
|
||||||
return tokens, masks
|
|
||||||
|
|
||||||
|
|
||||||
def make_openpi_observation_from_raw(
|
|
||||||
batch: dict[str, torch.Tensor],
|
|
||||||
*,
|
|
||||||
action_dim: int,
|
|
||||||
max_token_len: int,
|
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
|
||||||
pi05: bool,
|
|
||||||
) -> OpenPIObservation:
|
|
||||||
batch_size = batch[OBS_STATE].shape[0]
|
|
||||||
device = batch[OBS_STATE].device
|
|
||||||
state = openpi_model_state_from_raw(
|
|
||||||
batch,
|
|
||||||
action_dim=action_dim,
|
|
||||||
dataset_stats=dataset_stats,
|
|
||||||
pi05=pi05,
|
|
||||||
)
|
|
||||||
|
|
||||||
tasks = _tasks_from_raw(batch, batch_size)
|
|
||||||
prompts = _format_pi05_prompts(tasks, state) if pi05 else _format_pi0_prompts(tasks)
|
|
||||||
tokens, masks = _tokenize_prompts(prompts, max_token_len=max_token_len, device=device)
|
|
||||||
|
|
||||||
images = {
|
|
||||||
key: batch[f"observation.images.{key}"].to(device=device, dtype=torch.float32) * 2.0 - 1.0
|
|
||||||
for key in IMAGE_KEYS
|
|
||||||
}
|
|
||||||
image_masks = {key: torch.ones(batch_size, dtype=torch.bool, device=device) for key in IMAGE_KEYS}
|
|
||||||
|
|
||||||
return OpenPIObservation(
|
|
||||||
state=state,
|
|
||||||
images=images,
|
|
||||||
image_masks=image_masks,
|
|
||||||
tokenized_prompt=tokens,
|
|
||||||
tokenized_prompt_mask=masks,
|
|
||||||
token_ar_mask=torch.zeros_like(tokens, dtype=torch.int32),
|
|
||||||
token_loss_mask=torch.ones_like(masks, dtype=torch.bool),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def assert_processor_inputs_match_lerobot(
|
|
||||||
lerobot_policy,
|
|
||||||
lerobot_batch: dict[str, torch.Tensor],
|
|
||||||
openpi_observation: OpenPIObservation,
|
|
||||||
*,
|
|
||||||
compare_state: bool,
|
|
||||||
):
|
|
||||||
openpi_processed = openpi_preprocessing.preprocess_observation_pytorch(openpi_observation, train=False)
|
|
||||||
lerobot_images, lerobot_image_masks = lerobot_policy._preprocess_images(lerobot_batch)
|
|
||||||
|
|
||||||
# Token IDs, token masks, images, image masks, and PI0 state are intentionally built from the same
|
|
||||||
# raw batch through independent LeRobot/OpenPI-style processor logic. They must be bitwise equal.
|
|
||||||
torch.testing.assert_close(
|
|
||||||
openpi_observation.tokenized_prompt, lerobot_batch[OBS_LANGUAGE_TOKENS], rtol=0, atol=0
|
|
||||||
)
|
|
||||||
torch.testing.assert_close(
|
|
||||||
openpi_observation.tokenized_prompt_mask,
|
|
||||||
lerobot_batch[OBS_LANGUAGE_ATTENTION_MASK],
|
|
||||||
rtol=0,
|
|
||||||
atol=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
for openpi_image, lerobot_image in zip(openpi_processed.images.values(), lerobot_images, strict=True):
|
|
||||||
torch.testing.assert_close(openpi_image, lerobot_image, rtol=0, atol=0)
|
|
||||||
|
|
||||||
for openpi_mask, lerobot_mask in zip(
|
|
||||||
openpi_processed.image_masks.values(), lerobot_image_masks, strict=True
|
|
||||||
):
|
|
||||||
torch.testing.assert_close(openpi_mask, lerobot_mask, rtol=0, atol=0)
|
|
||||||
|
|
||||||
if compare_state:
|
|
||||||
torch.testing.assert_close(
|
|
||||||
openpi_processed.state, lerobot_policy.prepare_state(lerobot_batch), rtol=0, atol=0
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_openpi_reference_state_dict(repo_id: str) -> dict[str, torch.Tensor]:
|
|
||||||
cache_dir = Path(snapshot_download(repo_id=repo_id, repo_type="model"))
|
|
||||||
return safetensors.torch.load_file(cache_dir / "model.safetensors")
|
|
||||||
|
|
||||||
|
|
||||||
def fix_reference_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
|
||||||
fixed_state_dict = dict(state_dict)
|
|
||||||
lm_head_key = "paligemma_with_expert.paligemma.lm_head.weight"
|
|
||||||
embed_tokens_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
|
||||||
if lm_head_key in fixed_state_dict and embed_tokens_key not in fixed_state_dict:
|
|
||||||
fixed_state_dict[embed_tokens_key] = fixed_state_dict[lm_head_key].clone()
|
|
||||||
return fixed_state_dict
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def fixed_flow_sampling(model, *, noise: torch.Tensor, time: torch.Tensor) -> Iterator[None]:
|
|
||||||
original_sample_noise = model.sample_noise
|
|
||||||
original_sample_time = model.sample_time
|
|
||||||
|
|
||||||
def sample_noise(shape, device):
|
|
||||||
if tuple(shape) != tuple(noise.shape):
|
|
||||||
raise ValueError(f"Expected noise shape {tuple(noise.shape)}, got {tuple(shape)}")
|
|
||||||
return noise.to(device=device)
|
|
||||||
|
|
||||||
def sample_time(batch_size, device):
|
|
||||||
if batch_size != time.shape[0]:
|
|
||||||
raise ValueError(f"Expected time batch size {time.shape[0]}, got {batch_size}")
|
|
||||||
return time.to(device=device)
|
|
||||||
|
|
||||||
model.sample_noise = sample_noise
|
|
||||||
model.sample_time = sample_time
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
model.sample_noise = original_sample_noise
|
|
||||||
model.sample_time = original_sample_time
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def deterministic_openpi_forward_preprocess(openpi_policy) -> Iterator[None]:
|
|
||||||
"""Disable OpenPI's training-time image augmentation only inside a parity forward block.
|
|
||||||
|
|
||||||
OpenPI's `forward()` calls `_preprocess_observation(..., train=True)`, which can apply stochastic
|
|
||||||
image augmentation. LeRobot's policy forward path does not apply that augmentation, so parity would
|
|
||||||
otherwise compare two different image tensors rather than two model implementations. The context manager
|
|
||||||
keeps the public `openpi_policy.forward(observation, ...)` call while making preprocessing deterministic.
|
|
||||||
|
|
||||||
`yield` marks the body of the caller's `with` block. The `try/finally` restores the original method even
|
|
||||||
if the assertion inside the block fails, so the temporary monkeypatch cannot leak into later tests.
|
|
||||||
"""
|
|
||||||
|
|
||||||
original_preprocess_observation = openpi_policy._preprocess_observation
|
|
||||||
|
|
||||||
def preprocess_observation(observation, *, train=True):
|
|
||||||
return original_preprocess_observation(observation, train=False)
|
|
||||||
|
|
||||||
openpi_policy._preprocess_observation = preprocess_observation
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
openpi_policy._preprocess_observation = original_preprocess_observation
|
|
||||||
@@ -1,207 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import time
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch._dynamo.utils import counters, guard_failures
|
|
||||||
from torch.profiler import ProfilerActivity
|
|
||||||
|
|
||||||
FORWARD_RTOL = 1e-5
|
|
||||||
FORWARD_ATOL = 5e-2
|
|
||||||
SAMPLE_RTOL = 1e-5
|
|
||||||
SAMPLE_ATOL = 1e-2
|
|
||||||
COMPILE_MODE = "max-autotune"
|
|
||||||
STEADY_STATE_WARMUPS = 3
|
|
||||||
STEADY_STATE_REPEATS = 3
|
|
||||||
|
|
||||||
|
|
||||||
def make_compile_config(config_cls, *, compile_model):
|
|
||||||
return config_cls(device="cuda", compile_model=compile_model, compile_mode=COMPILE_MODE)
|
|
||||||
|
|
||||||
|
|
||||||
def counter_total(name):
|
|
||||||
return sum(counters.get(name, {}).values())
|
|
||||||
|
|
||||||
|
|
||||||
def compile_snapshot():
|
|
||||||
return {
|
|
||||||
"graph_breaks": counter_total("graph_break"),
|
|
||||||
"recompiles": counter_total("recompiles"),
|
|
||||||
"recompile_limits": counter_total("recompile_limit"),
|
|
||||||
"unique_graphs": counters["stats"].get("unique_graphs", 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def reset_compile_state():
|
|
||||||
torch._dynamo.reset()
|
|
||||||
counters.clear()
|
|
||||||
guard_failures.clear()
|
|
||||||
|
|
||||||
|
|
||||||
def clone_cuda_graph_output(output):
|
|
||||||
if torch.is_tensor(output):
|
|
||||||
return output.clone()
|
|
||||||
if isinstance(output, tuple):
|
|
||||||
return tuple(clone_cuda_graph_output(item) for item in output)
|
|
||||||
if isinstance(output, list):
|
|
||||||
return [clone_cuda_graph_output(item) for item in output]
|
|
||||||
if isinstance(output, dict):
|
|
||||||
return {key: clone_cuda_graph_output(value) for key, value in output.items()}
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def run_model_step(fn: Callable, kwargs: dict):
|
|
||||||
if hasattr(torch.compiler, "cudagraph_mark_step_begin"):
|
|
||||||
torch.compiler.cudagraph_mark_step_begin()
|
|
||||||
return fn(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def assert_explain_has_no_graph_breaks(fn: Callable, kwargs: dict, label: str):
|
|
||||||
reset_compile_state()
|
|
||||||
explanation = torch._dynamo.explain(fn)(**kwargs)
|
|
||||||
|
|
||||||
assert explanation.graph_count > 0, f"{label} was not captured by Dynamo"
|
|
||||||
assert explanation.graph_break_count == 0, (
|
|
||||||
f"{label} has {explanation.graph_break_count} graph break(s): {explanation.break_reasons}"
|
|
||||||
)
|
|
||||||
assert not explanation.break_reasons, f"{label} graph break reasons: {explanation.break_reasons}"
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"{label} capture: graphs={explanation.graph_count}, "
|
|
||||||
f"graph_breaks={explanation.graph_break_count}, ops={explanation.op_count}, "
|
|
||||||
f"guards={len(explanation.out_guards or [])}"
|
|
||||||
)
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs):
|
|
||||||
eager_forward = eager_model.forward(**forward_kwargs)
|
|
||||||
compiled_forward = compiled_model.forward(**forward_kwargs)
|
|
||||||
torch.testing.assert_close(compiled_forward, eager_forward, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
|
||||||
|
|
||||||
eager_actions = eager_model.sample_actions(**sample_kwargs)
|
|
||||||
compiled_actions = compiled_model.sample_actions(**sample_kwargs)
|
|
||||||
torch.testing.assert_close(compiled_actions, eager_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def assert_cache_stability(fn: Callable, kwargs: dict, label: str):
|
|
||||||
reset_compile_state()
|
|
||||||
|
|
||||||
first_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
|
||||||
first_snapshot = compile_snapshot()
|
|
||||||
second_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
|
||||||
second_snapshot = compile_snapshot()
|
|
||||||
third_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
|
||||||
third_snapshot = compile_snapshot()
|
|
||||||
|
|
||||||
torch.testing.assert_close(second_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
|
||||||
torch.testing.assert_close(third_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
|
||||||
assert first_snapshot["unique_graphs"] > 0, f"{label} did not compile any graph"
|
|
||||||
assert third_snapshot["graph_breaks"] == 0, f"{label} graph breaks: {third_snapshot}"
|
|
||||||
assert third_snapshot["recompiles"] == 0, f"{label} recompiled: {third_snapshot}"
|
|
||||||
assert third_snapshot["recompile_limits"] == 0, f"{label} hit recompile limit: {third_snapshot}"
|
|
||||||
assert second_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
|
|
||||||
f"{label} compiled new graph on second call: first={first_snapshot}, second={second_snapshot}"
|
|
||||||
)
|
|
||||||
assert third_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
|
|
||||||
f"{label} compiled new graph on third call: first={first_snapshot}, third={third_snapshot}"
|
|
||||||
)
|
|
||||||
assert not guard_failures, f"{label} guard failures: {dict(guard_failures)}"
|
|
||||||
|
|
||||||
print(f"{label} cache: first={first_snapshot}, third={third_snapshot}")
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def benchmark_runtime(eager_fn: Callable, compiled_fn: Callable, kwargs: dict, label: str):
|
|
||||||
run_warmups(eager_fn, kwargs)
|
|
||||||
run_warmups(compiled_fn, kwargs)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
eager_metrics = profile_callable(eager_fn, kwargs)
|
|
||||||
compiled_metrics = profile_callable(compiled_fn, kwargs)
|
|
||||||
speedup = eager_metrics["cuda_event_ms"] / compiled_metrics["cuda_event_ms"]
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"{label} runtime: eager_cuda={eager_metrics['cuda_event_ms']:.3f} ms, "
|
|
||||||
f"compiled_cuda={compiled_metrics['cuda_event_ms']:.3f} ms, speedup={speedup:.3f}x, "
|
|
||||||
f"host_wall_ms eager/compiled={eager_metrics['host_wall_ms']:.3f}/"
|
|
||||||
f"{compiled_metrics['host_wall_ms']:.3f}, "
|
|
||||||
f"cpu_self_time_ms eager/compiled={eager_metrics['cpu_self_time_ms']:.3f}/"
|
|
||||||
f"{compiled_metrics['cpu_self_time_ms']:.3f}, "
|
|
||||||
f"cuda_launches eager/compiled={eager_metrics['cuda_launch_count']}/"
|
|
||||||
f"{compiled_metrics['cuda_launch_count']}, "
|
|
||||||
f"profiler_events eager/compiled={eager_metrics['profiler_event_count']}/"
|
|
||||||
f"{compiled_metrics['profiler_event_count']}, "
|
|
||||||
f"peak_mem_mib eager/compiled={eager_metrics['peak_mem_mib']:.1f}/"
|
|
||||||
f"{compiled_metrics['peak_mem_mib']:.1f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert eager_metrics["cuda_event_ms"] > 0
|
|
||||||
assert compiled_metrics["cuda_event_ms"] > 0
|
|
||||||
assert eager_metrics["profiler_event_count"] > 0
|
|
||||||
assert compiled_metrics["profiler_event_count"] > 0
|
|
||||||
return eager_metrics, compiled_metrics
|
|
||||||
|
|
||||||
|
|
||||||
def run_warmups(fn: Callable, kwargs: dict):
|
|
||||||
for _ in range(STEADY_STATE_WARMUPS):
|
|
||||||
run_model_step(fn, kwargs)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
|
|
||||||
def profile_callable(fn: Callable, kwargs: dict):
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
|
|
||||||
start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
host_start = time.perf_counter()
|
|
||||||
start_event.record()
|
|
||||||
for _ in range(STEADY_STATE_REPEATS):
|
|
||||||
run_model_step(fn, kwargs)
|
|
||||||
end_event.record()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
cuda_event_ms = start_event.elapsed_time(end_event) / STEADY_STATE_REPEATS
|
|
||||||
host_wall_ms = (time.perf_counter() - host_start) * 1000 / STEADY_STATE_REPEATS
|
|
||||||
peak_mem_mib = torch.cuda.max_memory_allocated() / 1024**2
|
|
||||||
|
|
||||||
with torch.profiler.profile(
|
|
||||||
activities=[ProfilerActivity.CPU],
|
|
||||||
) as profiler:
|
|
||||||
run_model_step(fn, kwargs)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
key_averages = profiler.key_averages()
|
|
||||||
cpu_self_time_ms = sum(event.self_cpu_time_total for event in key_averages) / 1000
|
|
||||||
cuda_launch_count = sum(
|
|
||||||
event.count
|
|
||||||
for event in key_averages
|
|
||||||
if event.key in {"cudaLaunchKernel", "cudaGraphLaunch", "cudaLaunchKernelExC"}
|
|
||||||
)
|
|
||||||
profiler_event_count = sum(event.count for event in key_averages)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"cuda_event_ms": cuda_event_ms,
|
|
||||||
"host_wall_ms": host_wall_ms,
|
|
||||||
"cpu_self_time_ms": cpu_self_time_ms,
|
|
||||||
"cuda_launch_count": cuda_launch_count,
|
|
||||||
"profiler_event_count": profiler_event_count,
|
|
||||||
"peak_mem_mib": peak_mem_mib,
|
|
||||||
}
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""Compare the PI0.5 processor pipeline against the vendored OpenPI reference processors."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
pytest.importorskip("transformers")
|
|
||||||
|
|
||||||
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
|
|
||||||
from lerobot.policies.pi05 import PI05Policy # noqa: E402
|
|
||||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config # noqa: E402
|
|
||||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
|
||||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
|
||||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
|
||||||
IMAGE_KEYS,
|
|
||||||
assert_processor_inputs_match_lerobot,
|
|
||||||
clone_batch,
|
|
||||||
make_openpi_observation_from_raw,
|
|
||||||
openpi_model_actions_from_raw,
|
|
||||||
)
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(
|
|
||||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
|
||||||
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
|
|
||||||
)
|
|
||||||
|
|
||||||
DUMMY_ACTION_DIM = 32
|
|
||||||
DUMMY_STATE_DIM = 32
|
|
||||||
DUMMY_ACTION_HORIZON = 50
|
|
||||||
DUMMY_MAX_TOKEN_LEN = 200
|
|
||||||
DEVICE = torch.device("cpu")
|
|
||||||
|
|
||||||
DUMMY_DATASET_STATS = {
|
|
||||||
OBS_STATE: {
|
|
||||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
|
||||||
"std": torch.ones(DUMMY_STATE_DIM),
|
|
||||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
|
||||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
|
||||||
},
|
|
||||||
ACTION: {
|
|
||||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
|
||||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
|
||||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
|
||||||
"q99": torch.ones(DUMMY_ACTION_DIM),
|
|
||||||
},
|
|
||||||
"images": {
|
|
||||||
key: {
|
|
||||||
"mean": torch.zeros(3, 224, 224),
|
|
||||||
"std": torch.ones(3, 224, 224),
|
|
||||||
"q01": torch.zeros(3, 224, 224),
|
|
||||||
"q99": torch.ones(3, 224, 224),
|
|
||||||
}
|
|
||||||
for key in IMAGE_KEYS
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class PI05PolicyInputAdapter(torch.nn.Module):
|
|
||||||
"""Minimal adapter exposing PI0.5 policy image preparation without loading model weights."""
|
|
||||||
|
|
||||||
_preprocess_images = PI05Policy._preprocess_images
|
|
||||||
|
|
||||||
def __init__(self, config: PI05Config) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
|
|
||||||
|
|
||||||
|
|
||||||
def create_pi05_config() -> PI05Config:
|
|
||||||
config = PI05Config(device=str(DEVICE))
|
|
||||||
config.max_state_dim = DUMMY_STATE_DIM
|
|
||||||
config.max_action_dim = DUMMY_ACTION_DIM
|
|
||||||
config.chunk_size = DUMMY_ACTION_HORIZON
|
|
||||||
config.n_action_steps = DUMMY_ACTION_HORIZON
|
|
||||||
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
|
|
||||||
config.input_features = {
|
|
||||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
|
|
||||||
**{
|
|
||||||
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
|
|
||||||
for key in IMAGE_KEYS
|
|
||||||
},
|
|
||||||
}
|
|
||||||
config.output_features = {
|
|
||||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
|
|
||||||
}
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_data() -> dict:
|
|
||||||
batch_size = 2
|
|
||||||
prompt = "Pick up the red block and place it in the bin"
|
|
||||||
return {
|
|
||||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
|
||||||
ACTION: torch.randn(
|
|
||||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
|
||||||
),
|
|
||||||
**{
|
|
||||||
f"observation.images.{key}": torch.rand(
|
|
||||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
|
||||||
)
|
|
||||||
for key in IMAGE_KEYS
|
|
||||||
},
|
|
||||||
"task": [prompt for _ in range(batch_size)],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_pi05_processor_inputs_match_openpi_reference():
|
|
||||||
torch.manual_seed(0)
|
|
||||||
config = create_pi05_config()
|
|
||||||
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
|
|
||||||
|
|
||||||
raw_batch = create_dummy_data()
|
|
||||||
lerobot_batch = preprocessor(clone_batch(raw_batch))
|
|
||||||
openpi_observation = make_openpi_observation_from_raw(
|
|
||||||
raw_batch,
|
|
||||||
action_dim=DUMMY_ACTION_DIM,
|
|
||||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
|
||||||
dataset_stats=DUMMY_DATASET_STATS,
|
|
||||||
pi05=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert_processor_inputs_match_lerobot(
|
|
||||||
PI05PolicyInputAdapter(config),
|
|
||||||
lerobot_batch,
|
|
||||||
openpi_observation,
|
|
||||||
compare_state=False,
|
|
||||||
)
|
|
||||||
torch.testing.assert_close(
|
|
||||||
lerobot_batch[ACTION],
|
|
||||||
openpi_model_actions_from_raw(
|
|
||||||
raw_batch,
|
|
||||||
action_dim=DUMMY_ACTION_DIM,
|
|
||||||
dataset_stats=DUMMY_DATASET_STATS,
|
|
||||||
pi05=True,
|
|
||||||
),
|
|
||||||
rtol=0,
|
|
||||||
atol=0,
|
|
||||||
)
|
|
||||||
@@ -1,156 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""Compare the PI0 processor pipeline against the vendored OpenPI reference processors."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
pytest.importorskip("transformers")
|
|
||||||
|
|
||||||
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
|
|
||||||
from lerobot.policies.pi0 import PI0Policy # noqa: E402
|
|
||||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config # noqa: E402
|
|
||||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
|
||||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
|
||||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
|
||||||
IMAGE_KEYS,
|
|
||||||
assert_processor_inputs_match_lerobot,
|
|
||||||
clone_batch,
|
|
||||||
make_openpi_observation_from_raw,
|
|
||||||
openpi_model_actions_from_raw,
|
|
||||||
)
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(
|
|
||||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
|
||||||
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
|
|
||||||
)
|
|
||||||
|
|
||||||
DUMMY_ACTION_DIM = 32
|
|
||||||
DUMMY_STATE_DIM = 32
|
|
||||||
DUMMY_ACTION_HORIZON = 50
|
|
||||||
DUMMY_MAX_TOKEN_LEN = 48
|
|
||||||
DEVICE = torch.device("cpu")
|
|
||||||
|
|
||||||
DUMMY_DATASET_STATS = {
|
|
||||||
OBS_STATE: {
|
|
||||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
|
||||||
"std": torch.ones(DUMMY_STATE_DIM),
|
|
||||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
|
||||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
|
||||||
},
|
|
||||||
ACTION: {
|
|
||||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
|
||||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
|
||||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
|
||||||
"q99": torch.ones(DUMMY_ACTION_DIM),
|
|
||||||
},
|
|
||||||
"images": {
|
|
||||||
key: {
|
|
||||||
"mean": torch.zeros(3, 224, 224),
|
|
||||||
"std": torch.ones(3, 224, 224),
|
|
||||||
"q01": torch.zeros(3, 224, 224),
|
|
||||||
"q99": torch.ones(3, 224, 224),
|
|
||||||
}
|
|
||||||
for key in IMAGE_KEYS
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class PI0PolicyInputAdapter(torch.nn.Module):
|
|
||||||
"""Minimal adapter exposing PI0 policy input-preparation helpers without loading model weights."""
|
|
||||||
|
|
||||||
_preprocess_images = PI0Policy._preprocess_images
|
|
||||||
prepare_state = PI0Policy.prepare_state
|
|
||||||
|
|
||||||
def __init__(self, config: PI0Config) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
|
|
||||||
|
|
||||||
|
|
||||||
def create_pi0_config() -> PI0Config:
|
|
||||||
config = PI0Config(device=str(DEVICE))
|
|
||||||
config.max_state_dim = DUMMY_STATE_DIM
|
|
||||||
config.max_action_dim = DUMMY_ACTION_DIM
|
|
||||||
config.chunk_size = DUMMY_ACTION_HORIZON
|
|
||||||
config.n_action_steps = DUMMY_ACTION_HORIZON
|
|
||||||
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
|
|
||||||
config.input_features = {
|
|
||||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
|
|
||||||
**{
|
|
||||||
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
|
|
||||||
for key in IMAGE_KEYS
|
|
||||||
},
|
|
||||||
}
|
|
||||||
config.output_features = {
|
|
||||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
|
|
||||||
}
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_data() -> dict:
|
|
||||||
batch_size = 2
|
|
||||||
prompt = "Pick up the red block and place it in the bin"
|
|
||||||
return {
|
|
||||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
|
||||||
ACTION: torch.randn(
|
|
||||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
|
||||||
),
|
|
||||||
**{
|
|
||||||
f"observation.images.{key}": torch.rand(
|
|
||||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
|
||||||
)
|
|
||||||
for key in IMAGE_KEYS
|
|
||||||
},
|
|
||||||
"task": [prompt for _ in range(batch_size)],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_pi0_processor_inputs_match_openpi_reference():
|
|
||||||
torch.manual_seed(0)
|
|
||||||
config = create_pi0_config()
|
|
||||||
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
|
|
||||||
|
|
||||||
raw_batch = create_dummy_data()
|
|
||||||
lerobot_batch = preprocessor(clone_batch(raw_batch))
|
|
||||||
openpi_observation = make_openpi_observation_from_raw(
|
|
||||||
raw_batch,
|
|
||||||
action_dim=DUMMY_ACTION_DIM,
|
|
||||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
|
||||||
dataset_stats=DUMMY_DATASET_STATS,
|
|
||||||
pi05=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert_processor_inputs_match_lerobot(
|
|
||||||
PI0PolicyInputAdapter(config),
|
|
||||||
lerobot_batch,
|
|
||||||
openpi_observation,
|
|
||||||
compare_state=True,
|
|
||||||
)
|
|
||||||
torch.testing.assert_close(
|
|
||||||
lerobot_batch[ACTION],
|
|
||||||
openpi_model_actions_from_raw(
|
|
||||||
raw_batch,
|
|
||||||
action_dim=DUMMY_ACTION_DIM,
|
|
||||||
dataset_stats=DUMMY_DATASET_STATS,
|
|
||||||
pi05=False,
|
|
||||||
),
|
|
||||||
rtol=0,
|
|
||||||
atol=0,
|
|
||||||
)
|
|
||||||
22
uv.lock
generated
22
uv.lock
generated
@@ -3203,7 +3203,7 @@ requires-dist = [
|
|||||||
{ name = "pandas", marker = "extra == 'video-benchmark'", specifier = ">=2.2.2,<2.4.0" },
|
{ name = "pandas", marker = "extra == 'video-benchmark'", specifier = ">=2.2.2,<2.4.0" },
|
||||||
{ name = "peft", marker = "extra == 'peft-dep'", specifier = ">=0.18.0,<1.0.0" },
|
{ name = "peft", marker = "extra == 'peft-dep'", specifier = ">=0.18.0,<1.0.0" },
|
||||||
{ name = "pillow", specifier = ">=10.0.0,<13.0.0" },
|
{ name = "pillow", specifier = ">=10.0.0,<13.0.0" },
|
||||||
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.16" },
|
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.17" },
|
||||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" },
|
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" },
|
||||||
{ name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" },
|
{ name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" },
|
||||||
{ name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" },
|
{ name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" },
|
||||||
@@ -4592,7 +4592,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "placo"
|
name = "placo"
|
||||||
version = "0.9.15"
|
version = "0.9.16"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "cmeel" },
|
{ name = "cmeel" },
|
||||||
@@ -4602,16 +4602,16 @@ dependencies = [
|
|||||||
{ name = "pin" },
|
{ name = "pin" },
|
||||||
{ name = "rhoban-cmeel-jsoncpp" },
|
{ name = "rhoban-cmeel-jsoncpp" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/40/c4/a33a0ee2ad798471a1c43a96109d28f358fd95c78a56f8cff57acb66d2bc/placo-0.9.15.tar.gz", hash = "sha256:df47f1154bae305c943bd20ba4f56d50ffc65625efc98679fefb11e8ff3c462c", size = 136856, upload-time = "2025-11-03T10:49:13.151Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/9e/0a/36c5b729d0d69075e7dfafd1b36c4df6fbb8c1ff1585e88d3c56d4c15010/placo-0.9.16.tar.gz", hash = "sha256:5314faaf6442e7ffe17347680d236af953951813bbfb1c09c4a27f7388d332e4", size = 136871, upload-time = "2025-11-07T14:24:58.811Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/ef/03/207b1c087996b918fdbaa5a3a685e3b14b068cd303bf87affdf83f722b33/placo-0.9.15-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:eab7a299e73291fe631c02448b9e9826539f4824e198bcf85f7c91fdd77d054b", size = 1641975, upload-time = "2025-11-03T10:48:48.887Z" },
|
{ url = "https://files.pythonhosted.org/packages/a4/95/8a85b58033303fd354a680e1494f47801abdca9133c222ae1c2473983f25/placo-0.9.16-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:417a89920b340e3aec19f1f49e1fb06789c679a807450157af8bdf4aef4bc82b", size = 1641806, upload-time = "2025-11-07T14:24:34.736Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/92/55/40432b26bb1c5b9e677fbc41e8d85b54fa8897b7daebb2a22d410b0a7f7b/placo-0.9.15-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:23f9dd19b8d15fa9d86968948b57981ebc6f1decafeffc2d646d8b56f685b50d", size = 1515448, upload-time = "2025-11-03T10:48:50.562Z" },
|
{ url = "https://files.pythonhosted.org/packages/92/bd/2fb3556c71b0689b3168c0e85fce5befb605affcfe4afb3b5e7b5ba6749f/placo-0.9.16-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a7ef7ac33ba889d2122db0d7ed55eeecdffed020e2282712989bb11e408bab40", size = 1515468, upload-time = "2025-11-07T14:24:36.587Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/fd/8e/e6283201d329409dccf2045b5c1efd73b3dad5268143bbea4668029ca9c6/placo-0.9.15-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:2680a2166c23a0a2aa6226ad75c63a2b2310c812673a5db296616d9af053e076", size = 2106550, upload-time = "2025-11-03T10:48:52.364Z" },
|
{ url = "https://files.pythonhosted.org/packages/ea/fd/7dba380720dfb89df582a51d0b2cb43957a36849f676baa3dfc74704e67f/placo-0.9.16-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:885773fe8a8e809022451ec16d47479562a042596f663b8c5bbe762cd616f573", size = 2106540, upload-time = "2025-11-07T14:24:38.149Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/51/c3/77efe4c999e1d80ec14879ef73ea2a2144aa12db2b67870a562f87ed5b43/placo-0.9.15-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:1a2202a78bcd2874ca09a9a6526a95b38874803923cb9b3b4b96cd68ab4b7217", size = 2178531, upload-time = "2025-11-03T10:48:53.932Z" },
|
{ url = "https://files.pythonhosted.org/packages/7a/40/97c7c799fe4f89111b973d7a5f86626a2ec1d0e6e20ce2988e0a2bda66f5/placo-0.9.16-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:19f097305c714e539fbf19e761897f6daab2ff73f639319431b144e77dd3852e", size = 2178511, upload-time = "2025-11-07T14:24:40.04Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/fe/e7/b5cc5ad53414ff7af3357e0c9d97d902a3ce276e7810f8814fe9f0c1fb70/placo-0.9.15-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:84a445a99b059a512d1b4c64841a91d6f50149c7be9255c65bedeebbe6663989", size = 1641982, upload-time = "2025-11-03T10:48:55.277Z" },
|
{ url = "https://files.pythonhosted.org/packages/f7/4d/f1700aae269584477b5d72561d2fc5ace37b4bca167892a74a369849c67e/placo-0.9.16-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:be11fa987702114097ccf3d94e1c4a891796878429e25c8d88b187ecc652e7ae", size = 1641812, upload-time = "2025-11-07T14:24:41.308Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/ad/1c/1c9163d941698a077617f218041efc573d3bf5a1c169a284112bd622fccd/placo-0.9.15-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b3106e7e6b05cbfa494239d8aa14795f7da8ee5dec851602f0d6297e311d7334", size = 1515447, upload-time = "2025-11-03T10:48:56.975Z" },
|
{ url = "https://files.pythonhosted.org/packages/43/d7/21d1d0dd1311c0cbd9ccd233cdae520bbe2370095e3c831059d6077c90bd/placo-0.9.16-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c2d65aeb4844eae28006ad3a50c8519b27c701912cc99c46c95e33ed049f3635", size = 1515457, upload-time = "2025-11-07T14:24:42.758Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/cd/22/3d9b9045b89248c8476dd42243bc9821a123d9199e4e96a944124ad80cf1/placo-0.9.15-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:66c3d099e87551401aace04f1293a3c3563b1399319976647846845bf92c3ccf", size = 2106558, upload-time = "2025-11-03T10:48:58.667Z" },
|
{ url = "https://files.pythonhosted.org/packages/0f/e8/939ba23bfa539fb90ab9ab1c2c59ff9a9a46e24699fc90e8ca3ff2948646/placo-0.9.16-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a7633aff1c592c1f45e86a174a372d5d7972673935cb9151391277ff49ec2072", size = 2106538, upload-time = "2025-11-07T14:24:44.517Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/20/0b/45dbdd2c378a7cece578b7344fda493d5a2aa6777089798a315ce4f97c22/placo-0.9.15-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0e06b7d3d618ddc2b649ab8b0b46db8001fe72fe2fbcc801524df0ccc8a3da40", size = 2178531, upload-time = "2025-11-03T10:49:00.533Z" },
|
{ url = "https://files.pythonhosted.org/packages/08/00/ad24cc0ad85fbe12267df28c2061e1eaef8f852146c467fcd7a681e11028/placo-0.9.16-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0d97a7284b65fc45aef27865c80cf7e53f04646d35bb18494ab62dfbbc9a35bd", size = 2178514, upload-time = "2025-11-07T14:24:45.994Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
Reference in New Issue
Block a user