Compare commits

...

10 Commits

Author SHA1 Message Date
Maxime Ellerbach
5e2c4c22c8 Merge branch 'main' into feat/vla-jepa 2026-05-26 15:03:59 +02:00
Maximellerbach
1a536d1f71 smol fix to avoid having default CPU device when training 2026-05-26 15:03:29 +02:00
Haoming Song
5c98e80430 fix(gr00t): fix Eagle25VL model and processor crash in transformers>=5.4.0, <5.6.0 (#3652)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-26 14:04:22 +02:00
Reece O'Mahoney
f65f3f7a4a Fix policy.path in YAML configs (PR #3145 followup) (#3597)
PR #3145 added YAML support for policy.path but left two bugs:

1. extract_path_fields_from_config only deleted config_data[field] when
   no sibling overrides existed. With siblings, the dict stayed in place
   and draccus crashed decoding it as PreTrainedConfig (no 'type' key).
   Sibling overrides go into _config_yaml_overrides and are applied later
   by from_pretrained(), so the field can always be removed.

2. wrap() updated config_path_cli to the cleaned temp file path but
   never propagated it to the draccus.parse fallback branch. cli_args
   still contained --config_path=<original>, so draccus read the
   original YAML with path: still present.

Tests passed because they (a) called extract_path_fields_from_config
directly and (b) included type: alongside path: in the YAML, sidestepping
both bugs.

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-26 14:01:19 +02:00
Maximellerbach
d09410882d adding instructions for different embodiement + fixing some tests 2026-05-26 11:51:54 +02:00
Pepijn
8194897994 fix(deps): cap placo below 0.9.16 and harden kinematics import (#3647)
* fix(deps): cap placo below 0.9.16 and harden kinematics import

placo 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable
on Ubuntu 24.04 (noble ships urdfdom 3.x). Importing placo on that base
crashes with:

  ImportError: liburdfdom_sensor.so.4.0: cannot open shared object file

This broke nightly Latest Deps tests (CPU and GPU) when the lockfile
upgrade picked placo 0.9.16, since lerobot.model.kinematics
unconditionally imports placo when _placo_available is true, and that
check (importlib.util.find_spec) cannot detect dlopen failures of
transitive shared libraries — so unrelated subsystems (RL actor,
gym_manipulator) became unimportable.

Two changes:

1. Pin placo to <0.9.16 in pyproject.toml + regenerate uv.lock
   (0.9.16 → 0.9.15). Short-term unblock for nightly CI until system
   urdfdom 4.x is broadly available.

2. Harden the import guard in src/lerobot/model/kinematics.py:
   wrap 'import placo' in try/except ImportError so a missing
   transitive .so no longer crashes module import. RobotKinematics
   instantiation now raises an informative ImportError citing the
   underlying dlopen failure via _raise_if_placo_unusable().

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(kinematics): hoist _placo_runtime_error to module scope for mypy

Mypy walks the TYPE_CHECKING branch in which the runtime else-block is
not executed, so _placo_runtime_error was only defined at runtime and
mypy reported 'Name "_placo_runtime_error" is not defined' on the
three references inside _raise_if_placo_unusable. Declare the symbol
unconditionally at module scope with a default of None; the runtime
import-failure branch still assigns to it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* style(kinematics): drop verbose comments around placo import guard

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 12:03:07 +02:00
Haoming Song
9f437d86b6 fix(groot): align GR00TN15Config with transformers config dataclasses (#3606)
* fix(gr00t): fix gr00t config dataclass init TypeError

* fix(groot): guard strict config decorator without transformers for passing CI

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-05-22 10:31:04 +02:00
Haoming Song
b74a551d38 fix(pi0, pi05): stabilize torch.compile and expand test coverage (#3610)
* chore(gr00t): sync with #3606 for fixing gr00t config crash

* fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions

* fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete

* feat(test): add compile test and benchamrk for pi0 and pi05

* feat(test): add comprehensive testing for pi0 and pi05. Including processor, forward, sample action, etc.
2026-05-22 10:29:34 +02:00
Nikodem Bartnik
c0a2e9814d fix examples (#3623)
- Fixed broken API examples in Lerobot Imitation Learning Documentation
- Teleoperation with cameras improved by adding a fixed frequency in the loop (without it the cameras feed gets very slow)
- Wrapped record example script in main() to avoid problems on Mac
- Previously teleoperation example was using SO-ARM and teleoperation with cameras was using Koch. I changed it to use SO-ARM in all of the examples.
- Added section on how to train with HF Jobs - CLI and Python examples
- Replaced lerobot-record with lerobot-rollout in policies examples
2026-05-21 22:14:07 +02:00
Khalil Meftah
bac4f61eae refactor: support custom progress parquet overlays (#3640) 2026-05-21 14:32:10 +02:00
36 changed files with 2957 additions and 905 deletions

View File

@@ -79,17 +79,13 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
```bash
lerobot-record \
--robot.type=so100_follower \
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/act_policy \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=my_robot \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--display_data=true \
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
--dataset.num_episodes=10 \
--dataset.single_task="Your task description" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
--policy.path=${HF_USER}/act_policy
--task="Your task description" \ # can be skipped for ACT
--duration=60
```

View File

@@ -105,10 +105,12 @@ These results demonstrate GR00T's strong generalization capabilities across dive
### Evaluate in your hardware setup
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
```bash
lerobot-record \
lerobot-rollout\
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
--robot.type=bi_so_follower \
--robot.left_arm_port=/dev/ttyACM1 \
--robot.right_arm_port=/dev/ttyACM0 \
@@ -119,14 +121,12 @@ lerobot-record \
}' \
--display_data=true \
--dataset.repo_id=<user>/eval_groot-bimanual \
--dataset.num_episodes=10 \
--dataset.single_task="Grab and handover the red cube to the other arm" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
--policy.path=<user>/groot-bimanual \ # your trained model
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10
--duration=600
```
## License

View File

@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem58760431541",
id="my_red_robot_arm",
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
)
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem58760431551",
id="my_blue_leader_arm",
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
)
robot = SO101Follower(robot_config)
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
<hfoption id="Command">
```bash
lerobot-teleoperate \
--robot.type=koch_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=my_awesome_follower_arm \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
--teleop.type=koch_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=my_awesome_leader_arm \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem5AB90687491 \
--robot.id=my_follower_arm \
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so101_leader \
--teleop.port=/dev/tty.usbmodem5AB90689011 \
--teleop.id=my_leader_arm \
--display_data=true
```
</hfoption>
@@ -122,34 +122,48 @@ lerobot-teleoperate \
<!-- prettier-ignore-start -->
```python
import time
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
camera_config = {
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
}
robot_config = KochFollowerConfig(
port="/dev/tty.usbmodem585A0076841",
id="my_red_robot_arm",
cameras=camera_config
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
cameras={
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
}
)
teleop_config = KochLeaderConfig(
port="/dev/tty.usbmodem58760431551",
id="my_blue_leader_arm",
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
)
robot = KochFollower(robot_config)
teleop_device = KochLeader(teleop_config)
init_rerun(session_name="teleoperation")
robot = SO101Follower(robot_config)
teleop_device = SO101Leader(teleop_config)
robot.connect()
teleop_device.connect()
TARGET_HZ = 30
TIME_PER_FRAME = 1.0 / TARGET_HZ
while True:
start_time = time.perf_counter()
observation = robot.get_observation()
action = teleop_device.get_action()
robot.send_action(action)
log_rerun_data(observation=observation, action=action)
elapsed_time = time.perf_counter() - start_time
sleep_time = TIME_PER_FRAME - elapsed_time
if sleep_time > 0:
time.sleep(sleep_time)
```
<!-- prettier-ignore-end -->
@@ -202,10 +216,11 @@ lerobot-record \
<!-- prettier-ignore-start -->
```python
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.datasets import LeRobotDataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
@@ -218,71 +233,56 @@ EPISODE_TIME_SEC = 60
RESET_TIME_SEC = 10
TASK_DESCRIPTION = "My task description"
# Create robot configuration
robot_config = SO100FollowerConfig(
id="my_awesome_follower_arm",
cameras={
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
},
port="/dev/tty.usbmodem58760434471",
)
teleop_config = SO100LeaderConfig(
id="my_awesome_leader_arm",
port="/dev/tty.usbmodem585A0077581",
)
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
teleop = SO100Leader(teleop_config)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id="<hf_username>/<dataset_repo_id>",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
teleop.connect()
# Create the required processors
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
def main():
# Create robot configuration
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
cameras={
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
}
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
)
# Initialize the robot and teleoperator
robot = SO101Follower(robot_config)
teleop = SO101Leader(teleop_config)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id="<hf_username>/<dataset_repo_id>",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
teleop.connect()
# Create the required processors
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
@@ -291,26 +291,50 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
control_time_s=RESET_TIME_SEC,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
dataset.save_episode()
episode_idx += 1
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
teleop.disconnect()
dataset.push_to_hub()
dataset.save_episode()
episode_idx += 1
# finalize dataset
log_say("Finalizing dataset...")
dataset.finalize()
# Clean up
log_say("Stop recording")
robot.disconnect()
teleop.disconnect()
dataset.push_to_hub()
if __name__ == "__main__":
main()
```
<!-- prettier-ignore-end -->
@@ -348,7 +372,7 @@ The `record` function provides a suite of tools for capturing and managing data
##### 2. Checkpointing and Resuming
- Checkpoints are automatically created during recording.
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
- To start recording from scratch, **manually delete** the dataset directory.
##### 3. Recording Parameters
@@ -422,7 +446,7 @@ from lerobot.utils.utils import log_say
episode_idx = 0
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
robot = SO100Follower(robot_config)
robot.connect()
@@ -490,6 +514,83 @@ Additionally you can provide extra `tags` or specify a `license` for your model
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
#### Train using Hugging Face Jobs
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
To run the training use this command:
<hfoptions id="train_with_hf_jobs">
<hfoption id="Command">
```bash
hf jobs run \
--flavor a10g-small \
--timeout 4h \
--secrets HF_TOKEN \
huggingface/lerobot-gpu:latest \
-- \
python -m lerobot.scripts.lerobot_train \
--dataset.repo_id=username/dataset \
--policy.type=act \
--steps=5000 \
--batch_size=16 \
--policy.device=cuda \
--policy.repo_id=username/your_policy \
--log_freq=100
```
</hfoption>
<hfoption id="API example">
<!-- prettier-ignore-start -->
```python
from huggingface_hub import run_job, get_token
run_name = "act_so101_hf_jobs"
dataset_id = "username/dataset"
user_hub_id = "username"
command_args = [
"python", "-m", "lerobot.scripts.lerobot_train",
"--dataset.repo_id", dataset_id,
"--policy.type", "act",
"--steps", "5000",
"--batch_size", "16",
"--num_workers", "4",
"--policy.device", "cuda",
"--log_freq", "100",
"--save_freq", "1000",
"--save_checkpoint", "true",
"--wandb.enable", "false",
"--policy.repo_id", f"{user_hub_id}/{run_name}"
]
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
job_info = run_job(
image="huggingface/lerobot-gpu:latest",
command=command_args,
flavor="a10g-small",
timeout="4h",
secrets={"HF_TOKEN": get_token()}
)
print("\n🚀 Job successfully launched!")
print(f"🔹 Job ID: {job_info.id}")
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
```
<!-- prettier-ignore-end -->
</hfoption>
</hfoptions>
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
For longer training sessions increase the timeout.
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
#### Upload policy checkpoints
Once training is done, upload the latest checkpoint with:

View File

@@ -66,15 +66,16 @@ All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
Key parameters in `VLAJEPAConfig`:
| Parameter | Default | Description |
| ------------------------- | ------- | -------------------------------------------------------------- |
| `chunk_size` | 7 | Number of actions predicted per inference call |
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 8 | Video clip length fed to the world model |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
| Parameter | Default | Description |
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `chunk_size` | 7 | Number of actions predicted per inference call |
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 8 | Video clip length fed to the world model |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
---
@@ -110,6 +111,29 @@ lerobot-train \
--dataset.repo_id=your_org/your_dataset
```
### Fine-tuning on a different embodiment
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
The layers that depend on `action_dim` and `state_dim` are:
| Layer | Key prefix |
| ----------------------------------------- | ----------------------------------- |
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
--dataset.repo_id=your_org/your_dataset
```
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
### Reproducing the LIBERO results
**Training on LIBERO:**
@@ -132,7 +156,7 @@ lerobot-eval \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.n_episodes=10 \
--eval.batch_size=5 \
--eval.batch_size=5
```
@@ -145,9 +169,19 @@ lerobot-eval \
--env.task=libero_10 \
--env.task_ids='[0,1,2]' \
--eval.n_episodes=10 \
--eval.batch_size=5 \
--eval.batch_size=5
```
**Expected results:**
| Suite | Episodes | Successes | Success Rate |
| -------------- | -------- | --------- | ------------ |
| libero_spatial | 100 | 93 | **95.0%** |
| libero_object | 100 | 100 | **100.0%** |
| libero_goal | 100 | 98 | **98.0%** |
| libero_10 | 100 | 96 | **93.0%** |
| **Overall** | **400** | **387** | **96.5%** |
---
## Fine-tuning on single-camera datasets

View File

@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
Once you are logged in, you can run inference in your setup by doing:
```bash
lerobot-record \
lerobot-rollout \
--strategy.type=base \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \ # <- Use your port
--robot.id=my_blue_follower_arm \ # <- Use your robot id
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
--dataset.episode_time_s=50 \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
# <- RTC optional, use when running on low power hardware \
# --inference.type=rtc \
# --inference.rtc.execution_horizon=10 \
# --inference.rtc.max_guidance_weight=10.0 \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
# --teleop.id=my_red_leader_arm \
# --display_data=true #optional use if you want to see the camera stream \
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
```

View File

@@ -15,10 +15,12 @@
# limitations under the License.
"""
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
Downloads datasets from HuggingFace, seeks directly into the episode segment
of the source video, draws a progress line on each frame, and writes the result.
The progress data is read from a parquet file that lives alongside the dataset
(configurable via ``--progress-file``).
Usage:
python examples/dataset/create_progress_videos.py \
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
TASK_FONT_SCALE = 0.55
def download_episode_metadata(repo_id: str, episode: int) -> Path:
"""Download only the metadata and sarm_progress files for a dataset.
def download_episode_metadata(
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
) -> Path:
"""Download only the metadata and per-frame progress file for a dataset.
Args:
repo_id: HuggingFace dataset repository ID.
episode: Episode index (used for logging only; all meta is fetched).
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns:
Local cache path for the downloaded snapshot.
"""
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
local_path = Path(
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
allow_patterns=["meta/**", "sarm_progress.parquet"],
allow_patterns=["meta/**", progress_file],
ignore_patterns=["*.mp4"],
)
)
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
return video_path
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
"""Load sarm_progress values for an episode.
def load_progress_data(
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
) -> np.ndarray | None:
"""Load per-frame progress values for an episode.
Args:
local_path: Dataset cache root.
episode: Episode index.
progress_file: Filename of the per-frame progress parquet.
Returns:
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
"""
parquet_path = local_path / "sarm_progress.parquet"
parquet_path = local_path / progress_file
if not parquet_path.exists():
logging.warning("sarm_progress.parquet not found")
logging.warning("%s not found", progress_file)
return None
df = pd.read_parquet(parquet_path)
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
logging.info(" %s columns: %s", progress_file, list(df.columns))
episode_df = df[df["episode_index"] == episode].copy()
if episode_df.empty:
logging.warning("No sarm_progress rows for episode %d", episode)
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
return None
episode_df = episode_df.sort_values("frame_index")
@@ -576,6 +585,7 @@ def process_dataset(
camera_key: str | None,
output_dir: Path,
create_gif: bool = False,
progress_file: str = "sarm_progress.parquet",
) -> Path | None:
"""Full pipeline: download, extract metadata, composite progress, write output.
@@ -585,6 +595,8 @@ def process_dataset(
camera_key: Camera key to use, or None for auto-selection.
output_dir: Directory to write output files.
create_gif: If True, also generate a GIF from the MP4.
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns:
Path to the final output file, or None on failure.
@@ -592,7 +604,7 @@ def process_dataset(
safe_name = repo_id.replace("/", "_")
logging.info("Processing: %s | episode %d", repo_id, episode)
local_path = download_episode_metadata(repo_id, episode)
local_path = download_episode_metadata(repo_id, episode, progress_file)
logging.info(" Local cache: %s", local_path)
episode_meta = load_episode_meta(local_path, episode, camera_key)
@@ -600,9 +612,9 @@ def process_dataset(
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
progress_data = load_progress_data(local_path, episode)
progress_data = load_progress_data(local_path, episode, progress_file)
if progress_data is None:
logging.error("Could not load sarm_progress data. Skipping overlay.")
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
return None
logging.info(" Progress frames: %d", len(progress_data))
@@ -627,7 +639,7 @@ def process_dataset(
def main() -> None:
parser = argparse.ArgumentParser(
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
)
parser.add_argument(
"--repo-id",
@@ -658,6 +670,15 @@ def main() -> None:
action="store_true",
help="Also generate a GIF from the MP4 output.",
)
parser.add_argument(
"--progress-file",
type=str,
default="sarm_progress.parquet",
help=(
"Filename of the per-frame progress parquet inside the dataset repo "
"(default: 'sarm_progress.parquet')."
),
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
@@ -670,6 +691,7 @@ def main() -> None:
camera_key=args.camera_key,
output_dir=args.output_dir,
create_gif=args.gif,
progress_file=args.progress_file,
)
if result:

View File

@@ -138,7 +138,9 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
# Common
av-dep = ["av>=15.0.0,<16.0.0"]
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.9.17"]
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
placo-dep = ["placo>=0.9.6,<0.9.16"]
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]

View File

@@ -255,8 +255,7 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) ->
remaining = config_data[field]
if remaining:
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
else:
del config_data[field]
del config_data[field]
modified = True
if not modified:
@@ -311,7 +310,13 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
cli_args = filter_arg("config_path", cli_args)
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
else:
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
if config_path_cli:
cli_args = filter_arg("config_path", cli_args)
cfg = draccus.parse(
config_class=argtype,
config_path=config_path_cli or config_path,
args=cli_args,
)
response = fn(cfg, *args, **kwargs)
return response

View File

@@ -18,12 +18,25 @@ from typing import TYPE_CHECKING
import numpy as np
from lerobot.utils.import_utils import _placo_available, require_package
from lerobot.utils.import_utils import require_package
if TYPE_CHECKING or _placo_available:
_placo_runtime_error: ImportError | None = None
if TYPE_CHECKING:
import placo # type: ignore[import-not-found]
else:
placo = None
try:
import placo # type: ignore[import-not-found]
except ImportError as _placo_import_err:
placo = None
_placo_runtime_error = _placo_import_err
def _raise_if_placo_unusable() -> None:
if placo is None and _placo_runtime_error is not None:
raise ImportError(
f"placo is installed but failed to import: {_placo_runtime_error!s}"
) from _placo_runtime_error
class RobotKinematics:
@@ -44,6 +57,7 @@ class RobotKinematics:
joint_names (list[str] | None): List of joint names to use for the kinematics solver
"""
require_package("placo", extra="placo-dep")
_raise_if_placo_unusable()
self.robot = placo.RobotWrapper(urdf_path)
self.solver = placo.KinematicsSolver(self.robot)

View File

@@ -60,6 +60,7 @@ class Eagle25VLPreTrainedModel(PreTrainedModel):
"SiglipEncoderLayer",
]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = True
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_static_cache = True

View File

@@ -124,7 +124,6 @@ class Eagle25VLProcessor(ProcessorMixin):
"videos_kwargs",
"text_kwargs",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(

View File

@@ -14,7 +14,7 @@
# limitations under the License.
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
@@ -26,9 +26,14 @@ from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from huggingface_hub.dataclasses import strict
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
def strict(cls):
return cls
AutoConfig = None
AutoModel = None
PretrainedConfig = object
@@ -173,19 +178,20 @@ N_COLOR_CHANNELS = 3
# config
@strict
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
backbone_cfg: dict
action_head_cfg: dict
action_horizon: int
action_dim: int
backbone_cfg: dict[str, Any] | None = None
action_head_cfg: dict[str, Any] | None = None
action_horizon: int = 0
action_dim: int = 0
compute_dtype: str = "float32"
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
def __post_init__(self, **kwargs):
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
super().__post_init__(**kwargs)
# real model

View File

@@ -206,7 +206,11 @@ def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS
"Vendor files are copied during model creation. Create the policy/model first, "
"or call ensure_eagle_cache_ready() before building processors."
)
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
proc = AutoProcessor.from_pretrained(
str(cache_dir),
trust_remote_code=True,
fix_mistral_regex=False,
)
proc.tokenizer.padding_side = "left"
return proc

View File

@@ -15,7 +15,6 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
DynamicCache = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
@@ -141,6 +142,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
return att_2d_masks & pad_2d_masks
def clone_past_key_values(past_key_values):
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
return DynamicCache(
tuple(
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
)
)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -227,16 +237,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
@@ -258,15 +265,16 @@ def compute_layer_complete(
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
cos, sin = rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
paligemma_layer = layers[0]
scaling = paligemma_layer.self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
paligemma_layer.self_attn,
query_states,
key_states,
value_states,
@@ -274,13 +282,13 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma_layer.self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
@@ -488,8 +496,9 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
paligemma_layers = self.paligemma.model.language_model.layers
gemma_expert_layers = self.gemma_expert.model.layers
rotary_emb = self.paligemma.model.language_model.rotary_emb
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
@@ -499,36 +508,39 @@ class PaliGemmaWithExpertModel(
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
# final norm
final_norms = (
self.paligemma.model.language_model.norm,
self.gemma_expert.model.norm,
)
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -907,7 +919,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
past_key_values = clone_past_key_values(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,

View File

@@ -15,7 +15,6 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
DynamicCache = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
@@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
return att_2d_masks & pad_2d_masks
def clone_past_key_values(past_key_values):
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
return DynamicCache(
tuple(
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
)
)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -224,16 +234,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
@@ -255,15 +262,16 @@ def compute_layer_complete(
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
cos, sin = rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
paligemma_layer = layers[0]
scaling = paligemma_layer.self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
paligemma_layer.self_attn,
query_states,
key_states,
value_states,
@@ -271,13 +279,13 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma_layer.self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
@@ -485,8 +493,9 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
paligemma_layers = self.paligemma.model.language_model.layers
gemma_expert_layers = self.gemma_expert.model.layers
rotary_emb = self.paligemma.model.language_model.rotary_emb
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
@@ -496,36 +505,39 @@ class PaliGemmaWithExpertModel(
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
# final norm
final_norms = (
self.paligemma.model.language_model.norm,
self.gemma_expert.model.norm,
)
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -880,7 +892,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
past_key_values = clone_past_key_values(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,

View File

@@ -27,7 +27,12 @@ class VLAJEPAConfig(PreTrainedConfig):
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
freeze_qwen: bool = False
enable_world_model: bool = True
reinit_action_head: bool = False
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
# different action or state dimensionality, the input/output projection layers must be
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
reinit_modules: list[str] | None = None
tokenizer_padding_side: str = "left"
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."

View File

@@ -423,6 +423,7 @@ def main() -> None:
log.info(" Saving model.safetensors …")
save_safetensors(mapped_sd, save_dir / "model.safetensors")
config.device = None # don't bake in the conversion machine's device
config._save_pretrained(save_dir) # writes config.json via draccus
preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config, dataset_stats)

View File

@@ -219,14 +219,9 @@ class VLAJEPAModel(nn.Module):
b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
video_pixels = []
for i in range(b * v):
video_pixels.append(
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device)
)
video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W]
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device) # [B*V, T, C, H, W]
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
@@ -572,11 +567,8 @@ class VLAJEPAPolicy(PreTrainedPolicy):
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
"""
Custom loading to enable opt reinit of action head
when loading pretrained weights with mismatched action head shapes.
"""
if not model.config.reinit_action_head:
reinit_prefixes = model.config.reinit_modules
if not reinit_prefixes:
return super()._load_as_safetensor(model, model_file, map_location, strict)
from safetensors.torch import load_file
@@ -584,20 +576,25 @@ class VLAJEPAPolicy(PreTrainedPolicy):
state_dict = load_file(model_file, device=map_location)
current = model.state_dict()
mismatched: list[str] = []
reinitialized: list[str] = []
filtered: dict = {}
for key, value in state_dict.items():
if key in current and value.shape != current[key].shape:
mismatched.append(
f"{key}: checkpoint {tuple(value.shape)} vs model {tuple(current[key].shape)}"
if not any(key.startswith(p) for p in reinit_prefixes):
raise ValueError(
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
)
reinitialized.append(
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
)
else:
filtered[key] = value
if mismatched:
if reinitialized:
logging.warning(
f"reinit_action_head=True: skipping {len(mismatched)} tensor(s) with mismatched shapes "
f"(randomly re-initialised):\n " + "\n ".join(mismatched)
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
)
from lerobot.policies.utils import log_model_loading_keys

View File

@@ -0,0 +1 @@
"""Lightweight vendored OpenPI PyTorch modules for PI0/PI05 parity tests."""

View File

@@ -0,0 +1,22 @@
from dataclasses import dataclass
@dataclass
class Config:
width: int
depth: int
mlp_dim: int
num_heads: int
num_kv_heads: int
head_dim: int
def get_config(variant: str) -> Config:
"""Return the Gemma shape config needed by the OpenPI PyTorch model."""
if variant == "dummy":
return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16)
if variant == "gemma_300m":
return Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256)
if variant == "gemma_2b":
return Config(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256)
raise ValueError(f"Unknown variant: {variant}")

View File

@@ -0,0 +1,300 @@
from typing import Literal
import torch
from torch import nn
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
layernorm_forward,
)
class PaliGemmaWithExpertModel(nn.Module):
def __init__(
self,
vlm_config,
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
):
if use_adarms is None:
use_adarms = [False, False]
super().__init__()
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
vlm_config_hf.image_token_index = 257152
vlm_config_hf.text_config.hidden_size = vlm_config.width
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
hidden_size=action_expert_config.width,
intermediate_size=action_expert_config.mlp_dim,
num_attention_heads=action_expert_config.num_heads,
num_hidden_layers=action_expert_config.depth,
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
if precision == "bfloat16":
self.to(dtype=torch.bfloat16)
elif precision == "float32":
self.to(dtype=torch.float32)
return
else:
raise ValueError(f"Invalid precision: {precision}")
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_keep_float32):
param.data = param.data.to(dtype=torch.float32)
def embed_image(self, image: torch.Tensor):
# Transformers 5.4 no longer divides PaliGemma image features by sqrt(hidden_size),
# so the upstream helper now matches OpenPI's patched PaliGemma image-scale semantics.
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-c916907e7e52ac85ee1a1527560eae4656cd6c76141ceb1fe3da61bd5f697d2a
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
def forward(
self,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: list[torch.FloatTensor] | None = None,
use_cache: bool | None = None,
adarms_cond: list[torch.Tensor] | None = None,
):
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
)
prefix_past_key_values = prefix_output.past_key_values
prefix_output = prefix_output.last_hidden_state
suffix_output = None
elif inputs_embeds[0] is None:
suffix_output = self.gemma_expert.model.forward(
inputs_embeds=inputs_embeds[1],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
)
suffix_output = suffix_output.last_hidden_state
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
hasattr(self.gemma_expert.model, "gradient_checkpointing")
and self.gemma_expert.model.gradient_checkpointing
and self.training
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Force enable gradient checkpointing if we're in training mode and the model supports it
if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
if not self.gemma_expert.model.gradient_checkpointing:
print("Forcing gradient checkpointing to be enabled for Gemma expert model")
self.gemma_expert.model.gradient_checkpointing = True
use_gradient_checkpointing = True
# Debug gradient checkpointing status
if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
print(f"Model training mode: {self.training}")
print(
f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
)
if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
print(
f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
)
self._debug_gc_printed = True
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
models = [self.paligemma.model.language_model, self.gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(
layer.input_layernorm, hidden_states, adarms_cond[i]
)
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# Concatenate and process attention
query_states = torch.cat(query_states, dim=2)
key_states = torch.cat(key_states, dim=2)
value_states = torch.cat(value_states, dim=2)
dummy_tensor = torch.zeros(
query_states.shape[0],
query_states.shape[2],
query_states.shape[-1],
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = self.paligemma.model.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
self.paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
attention_mask,
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = self.paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
after_first_residual = out_emb.clone()
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = _gated_residual(after_first_residual, out_emb, gate)
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
)
# Old code removed - now using compute_layer_complete function above
# final norm
# Define final norm computation function for gradient checkpointing
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
# Apply gradient checkpointing to final norm if enabled
if use_gradient_checkpointing:
outputs_embeds = torch.utils.checkpoint.checkpoint(
compute_final_norms,
inputs_embeds,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
)
else:
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
prefix_output = outputs_embeds[0]
suffix_output = outputs_embeds[1]
prefix_past_key_values = None
return [prefix_output, suffix_output], prefix_past_key_values

View File

@@ -0,0 +1,79 @@
import torch
import torch.nn.functional as F # noqa: N812
def resize_with_pad_torch(
images: torch.Tensor,
height: int,
width: int,
mode: str = "bilinear",
) -> torch.Tensor:
"""PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
by padding with black. If the image is float32, it must be in the range [-1, 1].
Args:
images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
height: Target height
width: Target width
mode: Interpolation mode ('bilinear', 'nearest', etc.)
Returns:
Resized and padded tensor with same shape format as input
"""
# Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
if images.shape[-1] <= 4: # Assume channels-last format
channels_last = True
# Convert to channels-first for torch operations
if images.dim() == 3:
images = images.unsqueeze(0) # Add batch dimension
images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w]
else:
channels_last = False
if images.dim() == 3:
images = images.unsqueeze(0) # Add batch dimension
batch_size, channels, cur_height, cur_width = images.shape
# Calculate resize ratio
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
# Resize
resized_images = F.interpolate(
images,
size=(resized_height, resized_width),
mode=mode,
align_corners=False if mode == "bilinear" else None,
)
# Handle dtype-specific clipping
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(-1.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
# Calculate padding
pad_h0, remainder_h = divmod(height - resized_height, 2)
pad_h1 = pad_h0 + remainder_h
pad_w0, remainder_w = divmod(width - resized_width, 2)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else -1.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
mode="constant",
value=constant_value,
)
# Convert back to original format if needed
if channels_last:
padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
if batch_size == 1 and images.shape[0] == 1:
padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
return padded_images

View File

@@ -0,0 +1,471 @@
import copy
import logging
import math
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
import tests.policies.pi0_pi05.openpi_pytorch.gemma as _gemma
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as _preprocessing
from tests.policies.pi0_pi05.openpi_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "cpu":
# CPU doesn't support bfloat16, use float32 instead
if target_dtype == torch.bfloat16:
return torch.float32
if target_dtype == torch.float64:
return torch.float64
return target_dtype
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
def sample_beta(alpha, beta, bsize, device):
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,))
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
setup several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
themselves and the last 3 tokens have a causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend all previous blocks and all tokens on the same block.
Args:
input_mask: bool[B, N] true if its part of the input, false if padding.
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
return att_2d_masks & pad_2d_masks
class PI0Pytorch(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.pi05 = config.pi05
paligemma_config = _gemma.get_config(config.paligemma_variant)
action_expert_config = _gemma.get_config(config.action_expert_variant)
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_config,
action_expert_config,
use_adarms=[False, True] if self.pi05 else [False, False],
precision=config.dtype,
)
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)
if self.pi05:
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
else:
self.state_proj = nn.Linear(config.action_dim, action_expert_config.width)
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
torch.set_float32_matmul_precision("high")
if config.pytorch_compile_mode is not None:
self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode)
# Initialize gradient checkpointing flag
self.gradient_checkpointing_enabled = False
# The upstream OpenPI module verifies a site-package Transformers patch here.
# This vendored test copy instead routes through LeRobot's local PiGemma compatibility layer.
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
def is_gradient_checkpointing_enabled(self):
"""Check if gradient checkpointing is enabled."""
return self.gradient_checkpointing_enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
return torch.utils.checkpoint.checkpoint(
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
)
return func(*args, **kwargs)
def _prepare_attention_masks_4d(self, att_2d_masks):
"""Helper method to prepare 4D attention masks for transformer."""
att_2d_masks_4d = att_2d_masks[:, None, :, :]
return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
def _preprocess_observation(self, observation, *, train=True):
"""Helper method to preprocess observation."""
observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
return (
list(observation.images.values()),
list(observation.image_masks.values()),
observation.tokenized_prompt,
observation.tokenized_prompt_mask,
observation.state,
)
def sample_noise(self, shape, device):
return torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer to prepare
for PaliGemma transformer processing.
"""
embs = []
pad_masks = []
att_masks = []
# Process images
for img, img_mask in zip(images, img_masks, strict=True):
def image_embed_func(img):
return self.paligemma_with_expert.embed_image(img)
img_emb = self._apply_checkpoint(image_embed_func, img)
bsize, num_img_embs = img_emb.shape[:2]
embs.append(img_emb)
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
# Create attention masks so that image tokens attend to each other
att_masks += [0] * num_img_embs
# Process language tokens
def lang_embed_func(lang_tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
# Transformers > 5.4 scales Gemma token embeddings inside embed_tokens, matching
# OpenPI's former explicit sqrt(hidden_size) multiply without applying it twice.
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834
return lang_emb
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
embs.append(lang_emb)
pad_masks.append(lang_masks)
# full attention between image and language inputs
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
# Get batch size from the first dimension of the concatenated tensors
bsize = pad_masks.shape[0]
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def embed_suffix(self, state, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = []
pad_masks = []
att_masks = []
if not self.pi05:
if self.state_proj.weight.dtype == torch.float32:
state = state.to(torch.float32)
# Embed state
def state_proj_func(state):
return self.state_proj(state)
state_emb = self._apply_checkpoint(state_proj_func, state)
embs.append(state_emb[:, None, :])
bsize = state_emb.shape[0]
device = state_emb.device
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
pad_masks.append(state_mask)
# Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1]
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep,
self.action_in_proj.out_features,
min_period=4e-3,
max_period=4.0,
device=timestep.device,
)
time_emb = time_emb.type(dtype=timestep.dtype)
# Fuse timestep + action information using an MLP
def action_proj_func(noisy_actions):
return self.action_in_proj(noisy_actions)
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
if not self.pi05:
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
# Apply MLP layers
def mlp_func(action_time_emb):
x = self.action_time_mlp_in(action_time_emb)
x = F.silu(x) # swish == silu
return self.action_time_mlp_out(x)
action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
adarms_cond = None
else:
# time MLP (for adaRMS)
def time_mlp_func(time_emb):
x = self.time_mlp_in(time_emb)
x = F.silu(x) # swish == silu
x = self.time_mlp_out(x)
return F.silu(x)
time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
action_time_emb = action_emb
adarms_cond = time_emb
# Add to input tokens
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
att_masks += [1] + ([0] * (self.config.action_horizon - 1))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks, adarms_cond
def forward(self, observation, actions, noise=None, time=None) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
observation, train=True
)
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
# Prepare attention masks
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
# Apply gradient checkpointing if enabled
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
return suffix_out
suffix_out = self._apply_checkpoint(
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
)
suffix_out = suffix_out[:, -self.config.action_horizon :]
suffix_out = suffix_out.to(dtype=torch.float32)
# Apply gradient checkpointing to final action projection if enabled
def action_out_proj_func(suffix_out):
return self.action_out_proj(suffix_out)
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad()
def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = observation.state.shape[0]
if noise is None:
actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
noise = self.sample_noise(actions_shape, device)
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
observation, train=False
)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# Compute image and language key value cache
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks_4d,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=True,
)
dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Euler step - use new tensor assignment instead of in-place operation
x_t = x_t + dt * v_t
time += dt
return x_t
def denoise_step(
self,
state,
prefix_pad_masks,
past_key_values,
x_t,
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
# Prepare attention masks
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.action_horizon :]
suffix_out = suffix_out.to(dtype=torch.float32)
return self.action_out_proj(suffix_out)

View File

@@ -0,0 +1,179 @@
import logging
from collections.abc import Sequence
import torch
from tests.policies.pi0_pi05.openpi_pytorch import image_tools
logger = logging.getLogger("openpi")
# Constants moved from model.py
IMAGE_KEYS = (
"base_0_rgb",
"left_wrist_0_rgb",
"right_wrist_0_rgb",
)
IMAGE_RESOLUTION = (224, 224)
def preprocess_observation_pytorch(
observation,
*,
train: bool = False,
image_keys: Sequence[str] = IMAGE_KEYS,
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
):
"""Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
This function avoids complex type annotations that can cause torch.compile issues.
"""
if not set(image_keys).issubset(observation.images):
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
batch_shape = observation.state.shape[:-1]
out_images = {}
for key in image_keys:
image = observation.images[key]
# TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
# Handle both [B, C, H, W] and [B, H, W, C] formats
is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1
if is_channels_first:
# Convert [B, C, H, W] to [B, H, W, C] for processing
image = image.permute(0, 2, 3, 1)
if image.shape[1:3] != image_resolution:
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
image = image_tools.resize_with_pad_torch(image, *image_resolution)
if train:
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
image = image / 2.0 + 0.5
# Apply PyTorch-based augmentations
if "wrist" not in key:
# Geometric augmentations for non-wrist cameras
height, width = image.shape[1:3]
# Random crop and resize
crop_height = int(height * 0.95)
crop_width = int(width * 0.95)
# Random crop
max_h = height - crop_height
max_w = width - crop_width
if max_h > 0 and max_w > 0:
# Use tensor operations instead of .item() for torch.compile compatibility
start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
# Resize back to original size
image = torch.nn.functional.interpolate(
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
size=(height, width),
mode="bilinear",
align_corners=False,
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
# Random rotation (small angles)
# Use tensor operations instead of .item() for torch.compile compatibility
angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
# Convert to radians
angle_rad = angle * torch.pi / 180.0
# Create rotation matrix
cos_a = torch.cos(angle_rad)
sin_a = torch.sin(angle_rad)
# Apply rotation using grid_sample
grid_x = torch.linspace(-1, 1, width, device=image.device)
grid_y = torch.linspace(-1, 1, height, device=image.device)
# Create meshgrid
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
# Expand to batch dimension
grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)
# Apply rotation transformation
grid_x_rot = grid_x * cos_a - grid_y * sin_a
grid_y_rot = grid_x * sin_a + grid_y * cos_a
# Stack and reshape for grid_sample
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
image = torch.nn.functional.grid_sample(
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
grid,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
# Color augmentations for all cameras
# Random brightness
# Use tensor operations instead of .item() for torch.compile compatibility
brightness_factor = (
0.7 + torch.rand(1, device=image.device) * 0.6
) # Random factor between 0.7 and 1.3
image = image * brightness_factor
# Random contrast
# Use tensor operations instead of .item() for torch.compile compatibility
contrast_factor = (
0.6 + torch.rand(1, device=image.device) * 0.8
) # Random factor between 0.6 and 1.4
mean = image.mean(dim=[1, 2, 3], keepdim=True)
image = (image - mean) * contrast_factor + mean
# Random saturation (convert to HSV, modify S, convert back)
# For simplicity, we'll just apply a random scaling to the color channels
# Use tensor operations instead of .item() for torch.compile compatibility
saturation_factor = (
0.5 + torch.rand(1, device=image.device) * 1.0
) # Random factor between 0.5 and 1.5
gray = image.mean(dim=-1, keepdim=True)
image = gray + (image - gray) * saturation_factor
# Clamp values to [0, 1]
image = torch.clamp(image, 0, 1)
# Back to [-1, 1]
image = image * 2.0 - 1.0
# Convert back to [B, C, H, W] format if it was originally channels-first
if is_channels_first:
image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
out_images[key] = image
# obtain mask
out_masks = {}
for key in out_images:
if key not in observation.image_masks:
# do not mask by default
out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)
else:
out_masks[key] = observation.image_masks[key]
# Create a simple object with the required attributes instead of using the complex Observation class
class SimpleProcessedObservation:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
return SimpleProcessedObservation(
images=out_images,
image_masks=out_masks,
state=observation.state,
tokenized_prompt=observation.tokenized_prompt,
tokenized_prompt_mask=observation.tokenized_prompt_mask,
token_ar_mask=observation.token_ar_mask,
token_loss_mask=observation.token_loss_mask,
)

View File

@@ -0,0 +1,101 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.policies.pi05 import PI05Config # noqa: E402
from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch # noqa: E402
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
assert_cache_stability,
assert_compiled_output_matches_eager,
assert_explain_has_no_graph_breaks,
benchmark_runtime,
make_compile_config,
reset_compile_state,
)
from tests.utils import require_cuda # noqa: E402
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
)
def _make_model(*, compile_model):
return PI05Pytorch(make_compile_config(PI05Config, compile_model=compile_model)).cuda().eval()
def _make_dummy_inputs(config):
device = torch.device("cuda")
common = {
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
"tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
"masks": torch.ones(1, 5, dtype=torch.bool, device=device),
}
forward_kwargs = {
**common,
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"time": torch.rand(1, device=device),
}
sample_kwargs = {
**common,
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"num_steps": config.num_inference_steps,
}
return forward_kwargs, sample_kwargs
@require_cuda
def test_pi05_torch_compile_forward_and_sample_actions():
if not hasattr(torch, "compile"):
pytest.skip("torch.compile is not available")
if not torch._dynamo.is_dynamo_supported():
pytest.skip("torch._dynamo is not supported on this platform")
torch.manual_seed(0)
eager_model = _make_model(compile_model=False)
torch.manual_seed(0)
compiled_model = _make_model(compile_model=True)
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
try:
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi05.forward")
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi05.sample_actions")
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi05.forward")
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi05.sample_actions")
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi05.forward")
benchmark_runtime(
eager_model.sample_actions,
compiled_model.sample_actions,
sample_kwargs,
"pi05.sample_actions",
)
finally:
reset_compile_state()
del eager_model
del compiled_model
torch.cuda.empty_cache()

View File

@@ -14,52 +14,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation"""
"""Compare LeRobot PI0.5 against the vendored OpenPI PyTorch reference."""
import gc
import os
from copy import deepcopy
from typing import Any
import numpy as np
import pytest
import torch
# Skip if openpi or transformers is not available
pytest.importorskip("openpi")
pytest.importorskip("transformers")
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
from lerobot.configs import PreTrainedConfig # noqa: E402
from lerobot.policies.pi05 import PI05Policy # noqa: E402
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
assert_processor_inputs_match_lerobot,
clone_batch,
deterministic_openpi_forward_preprocess,
fix_reference_state_dict,
fixed_flow_sampling,
load_openpi_reference_state_dict,
make_openpi_observation_from_raw,
openpi_model_actions_from_raw,
)
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
)
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from transformers import AutoTokenizer # noqa: E402
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
from lerobot.types import PolicyAction # noqa: E402
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 200
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
COMPILE_MODE = "default"
FORWARD_RTOL = 1e-4
FORWARD_ATOL = 1e-4
SAMPLE_RTOL = 1e-2
SAMPLE_ATOL = 5e-3
DUMMY_DATASET_STATS = {
"observation.state": {
OBS_STATE: {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
"q01": torch.zeros(DUMMY_STATE_DIM),
"q99": torch.ones(DUMMY_STATE_DIM),
},
"action": {
ACTION: {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
"q01": torch.zeros(DUMMY_ACTION_DIM),
@@ -88,6 +92,15 @@ DUMMY_DATASET_STATS = {
}
@pytest.fixture(autouse=True)
def cleanup_cuda_after_test():
yield
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
class PI05BaseOriginalConfig:
action_dim: int = DUMMY_ACTION_DIM
action_horizon: int = DUMMY_ACTION_HORIZON
@@ -96,341 +109,163 @@ class PI05BaseOriginalConfig:
precision: str = "float32"
pi05: bool = True
dtype: str = "float32"
pytorch_compile_mode: str | None = None
def instantiate_lerobot_pi05(
from_pretrained: bool = False,
) -> tuple[
PI05Policy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
if from_pretrained:
# Load the policy first
policy = PI05Policy.from_pretrained(pretrained_name_or_path="lerobot/pi05_base", strict=True)
else:
config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
policy = PI05Policy(config)
def instantiate_lerobot_pi05(*, compile_model: bool = False, gradient_checkpointing: bool = False):
config = PreTrainedConfig.from_pretrained("lerobot/pi05_base")
config.device = str(DEVICE)
config.dtype = "float32"
config.compile_model = compile_model
config.compile_mode = COMPILE_MODE
config.gradient_checkpointing = gradient_checkpointing
policy = PI05Policy.from_pretrained("lerobot/pi05_base", config=config, strict=True)
policy.to(DEVICE)
policy.config.device = DEVICE
preprocessor, postprocessor = make_pi05_pre_post_processors(
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
)
return (policy, preprocessor, postprocessor)
policy.config.device = str(DEVICE)
preprocessor, _ = make_pi05_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
return policy, preprocessor
def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None):
config = PI05BaseOriginalConfig()
policy = PI0Pytorch(config)
def instantiate_original_pi05():
policy = PI0Pytorch(PI05BaseOriginalConfig()).to(DEVICE)
if from_pretrained:
try:
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi05_base)...")
# Download the model from HuggingFace Hub
import safetensors.torch
from huggingface_hub import snapshot_download
# Download the entire repository
if model_path and os.path.exists(model_path):
cache_dir = model_path
print(f"Using cached model from: {cache_dir}")
else:
cache_dir = snapshot_download(repo_id="lerobot/pi05_base", repo_type="model")
print(f"Downloaded model to: {cache_dir}")
# Try to load safetensors format first
model_file = os.path.join(cache_dir, "model.safetensors")
if os.path.exists(model_file):
state_dict = safetensors.torch.load_file(model_file)
print(f"Loaded {len(state_dict)} parameters from safetensors")
else:
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
# Load the state dict into the model
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
if missing_keys:
print(f"Missing keys: {len(missing_keys)}")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"Unexpected keys: {len(unexpected_keys)}")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All pretrained weights loaded successfully!")
else:
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
except Exception as e:
print(f"Failed to load pretrained weights: {e}")
print(" Using randomly initialized weights...")
import traceback
traceback.print_exc()
policy.to(DEVICE)
# NOTE: `lerobot/pi05_base` 的 LeRobot loader 和 PI0 一样会在 strict load 前做 key
# 兼容转换,因此预期没有 missing_keys 或 unexpected_keys。vendored reference 则是裸
# `nn.Module`,需要在测试侧补齐 checkpoint 与模块命名之间的最小差异。
# NOTE: `lm_head.weight` 是 PaliGemma tied embedding 的保存名LeRobot 的
# from_pretrained 会把它映射到内部 `embed_tokens.weight`,而 reference 模型没有这层
# loader所以这里手动复用同一份 tensor避免把权重别名差异误判成模型差异。
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi05_base"))
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
assert missing_keys == []
assert unexpected_keys == []
return policy
def create_dummy_data():
batch_size = 2 # Reduce batch size for testing
device = DEVICE
# Use the exact same prompt for both implementations
batch_size = 2
prompt = "Pick up the red block and place it in the bin"
batch = {
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
"action": torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
return {
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
ACTION: torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
),
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
),
"observation.images.left_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
),
"observation.images.right_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
),
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
"task": [prompt for _ in range(batch_size)],
}
return batch
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
"""Extract the exact same processed inputs that LeRobot uses internally."""
# Get the tokenized language from LeRobot's internal method
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
# Get the preprocessed images from LeRobot's internal method
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
# Create dummy token_ar_mask and token_loss_mask for original implementation
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
def prepare_parity_inputs(lerobot_pi05, lerobot_preprocessor):
torch.manual_seed(0)
raw_batch = create_dummy_data()
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
openpi_observation = make_openpi_observation_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
max_token_len=DUMMY_MAX_TOKEN_LEN,
dataset_stats=DUMMY_DATASET_STATS,
pi05=True,
)
openpi_actions = openpi_model_actions_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
dataset_stats=DUMMY_DATASET_STATS,
pi05=True,
)
assert_processor_inputs_match_lerobot(
lerobot_pi05,
lerobot_batch,
openpi_observation,
compare_state=False,
)
batch_size = raw_batch[OBS_STATE].shape[0]
noise = torch.randn(
batch_size,
DUMMY_ACTION_HORIZON,
DUMMY_ACTION_DIM,
dtype=torch.float32,
device=DEVICE,
)
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
return lerobot_batch, openpi_observation, openpi_actions, noise, time
class PI05Observation:
"""Observation class that matches the original OpenPI format."""
def __init__(
self,
state,
images,
image_masks,
tokenized_prompt,
tokenized_prompt_mask,
token_ar_mask,
token_loss_mask,
):
self.state = state
self.images = images
self.image_masks = image_masks
self.tokenized_prompt = tokenized_prompt
self.tokenized_prompt_mask = tokenized_prompt_mask
self.token_ar_mask = token_ar_mask
self.token_loss_mask = token_loss_mask
def create_original_observation_with_openpi_preprocessing(batch):
"""Create observation object for OpenPI using OpenPI's own preprocessing with pi05 state tokenizer."""
batch_size = batch["observation.state"].shape[0]
device = batch["observation.state"].device
# Create tokenizer for OpenPI (same as LeRobot uses)
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
# Get task description (pi05 processor handles all text formatting)
tasks = batch.get("task", ["Pick up the object"] * batch_size)
if isinstance(tasks, str):
tasks = [tasks] * batch_size
elif len(tasks) == 1:
tasks = tasks * batch_size
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
state = batch["observation.state"]
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
from lerobot.policies.pi05.modeling_pi05 import pad_vector
state = pad_vector(state, DUMMY_STATE_DIM)
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Create pi05-formatted prompts that include state information
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
# Tokenize with max_length padding to match OpenPI's expected format
tokenized = tokenizer(
full_prompts,
padding="max_length",
padding_side="right",
truncation=True,
max_length=DUMMY_MAX_TOKEN_LEN,
return_tensors="pt",
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(
compile_model=compile_model,
gradient_checkpointing=gradient_checkpointing,
)
original_pi05 = instantiate_original_pi05()
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
lerobot_pi05,
lerobot_preprocessor,
)
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
if gradient_checkpointing:
lerobot_pi05.train()
else:
lerobot_pi05.eval()
original_pi05.eval()
# Create dummy token_ar_mask and token_loss_mask for OpenPI
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
with fixed_flow_sampling(lerobot_pi05.model, noise=noise, time=time):
lerobot_loss, _ = lerobot_pi05(lerobot_batch, reduction="none")
with deterministic_openpi_forward_preprocess(original_pi05):
openpi_losses = original_pi05(openpi_observation, openpi_actions, noise=noise, time=time)
openpi_loss = openpi_losses.mean(dim=(1, 2))
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
image_dict = {
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
}
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
# Create image masks (all ones for real images)
image_masks_dict = {}
for key in image_dict:
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
# Create raw observation object (before preprocessing)
raw_observation = PI05Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(compile_model=compile_model)
original_pi05 = instantiate_original_pi05()
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
lerobot_pi05,
lerobot_preprocessor,
)
# Now use OpenPI's preprocessing
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
return processed_obs
def create_original_observation_from_lerobot(lerobot_pi0, batch):
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
_batch_size = batch["observation.state"].shape[0]
_device = batch["observation.state"].device
# Extract the exact same processed inputs that LeRobot uses
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
extract_lerobot_processed_inputs(lerobot_pi0, batch)
)
# Convert images list to dict with original OpenPI keys
image_dict = {
"base_0_rgb": images[0],
"left_wrist_0_rgb": images[1],
"right_wrist_0_rgb": images[2],
}
# Convert image masks list to dict with original OpenPI keys
image_masks_dict = {
"base_0_rgb": img_masks[0],
"left_wrist_0_rgb": img_masks[1],
"right_wrist_0_rgb": img_masks[2],
}
return PI05Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
def test_pi05_original_vs_lerobot():
"""Test PI05 original implementation vs LeRobot implementation."""
print("Initializing models...")
lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05(
from_pretrained=True
) # Load pretrained LeRobot model
original_pi0 = instantiate_original_pi05(
from_pretrained=True
) # Load pretrained OpenPI model from HuggingFace Hub
print("Creating dummy data...")
batch = create_dummy_data()
batch_lerobot = deepcopy(batch)
# Test each model with its own preprocessing (more realistic end-to-end test)
print("\nTest each model with its own preprocessing")
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
print(f"Task prompt: '{batch['task'][0]}'")
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
print("Testing OpenPI with own preprocessing...")
original_pi0.eval()
torch.manual_seed(42) # Set seed for reproducibility
batch_size = batch["observation.state"].shape[0]
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
openpi_actions = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
)
openpi_actions_unit = openpi_actions[:, 0, :]
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
print("Testing LeRobot with own preprocessing...")
lerobot_pi05.eval()
torch.manual_seed(42) # Set the same seed
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
original_pi05.eval()
with torch.no_grad():
lerobot_actions_own = lerobot_pi05.predict_action_chunk(
batch_lerobot_processed
) # batch_size, n_action_steps, action_dim
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
lerobot_actions = lerobot_pi05.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
openpi_actions = original_pi05.sample_actions(
device=DEVICE,
observation=openpi_observation,
noise=noise,
num_steps=10,
)
print("\nComparing end-to-end implementations:")
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
def test_pi05_forward_matches_openpi():
assert_forward_matches()
def test_pi05_sample_actions_match_openpi():
assert_sample_actions_match_openpi()
def test_pi05_gradient_checkpointing_forward_matches_openpi():
assert_forward_matches(gradient_checkpointing=True)
def test_pi05_compile_forward_matches_openpi():
assert_forward_matches(compile_model=True)
def test_pi05_compile_sample_actions_match_openpi():
assert_sample_actions_match_openpi(compile_model=True)
def test_pi05_compile_gradient_checkpointing_forward_matches_openpi():
assert_forward_matches(compile_model=True, gradient_checkpointing=True)

View File

@@ -0,0 +1,99 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.policies.pi0 import PI0Config # noqa: E402
from lerobot.policies.pi0.modeling_pi0 import PI0Pytorch # noqa: E402
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
assert_cache_stability,
assert_compiled_output_matches_eager,
assert_explain_has_no_graph_breaks,
benchmark_runtime,
make_compile_config,
reset_compile_state,
)
from tests.utils import require_cuda # noqa: E402
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
)
def _make_model(*, compile_model):
return PI0Pytorch(make_compile_config(PI0Config, compile_model=compile_model)).cuda().eval()
def _make_dummy_inputs(config):
device = torch.device("cuda")
common = {
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
"lang_tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
"lang_masks": torch.ones(1, 5, dtype=torch.bool, device=device),
"state": torch.randn(1, config.max_state_dim, device=device),
}
forward_kwargs = {
**common,
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"time": torch.rand(1, device=device),
}
sample_kwargs = {
**common,
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"num_steps": config.num_inference_steps,
}
return forward_kwargs, sample_kwargs
@require_cuda
def test_pi0_torch_compile_forward_and_sample_actions():
if not hasattr(torch, "compile"):
pytest.skip("torch.compile is not available")
if not torch._dynamo.is_dynamo_supported():
pytest.skip("torch._dynamo is not supported on this platform")
torch.manual_seed(0)
eager_model = _make_model(compile_model=False)
torch.manual_seed(0)
compiled_model = _make_model(compile_model=True)
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
try:
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi0.forward")
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi0.sample_actions")
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi0.forward")
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions")
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi0.forward")
benchmark_runtime(
eager_model.sample_actions, compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions"
)
finally:
reset_compile_state()
del eager_model
del compiled_model
torch.cuda.empty_cache()

View File

@@ -14,51 +14,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation"""
"""Compare LeRobot PI0 against the vendored OpenPI PyTorch reference."""
import gc
import os
from copy import deepcopy
from typing import Any
import pytest
import torch
# Skip if openpi or transformers is not available
pytest.importorskip("openpi")
pytest.importorskip("transformers")
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
from lerobot.configs import PreTrainedConfig # noqa: E402
from lerobot.policies.pi0 import PI0Policy # noqa: E402
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
assert_processor_inputs_match_lerobot,
clone_batch,
deterministic_openpi_forward_preprocess,
fix_reference_state_dict,
fixed_flow_sampling,
load_openpi_reference_state_dict,
make_openpi_observation_from_raw,
openpi_model_actions_from_raw,
)
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
)
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from transformers import AutoTokenizer # noqa: E402
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
from lerobot.types import PolicyAction # noqa: E402
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05)
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
DUMMY_MAX_TOKEN_LEN = 48
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
COMPILE_MODE = "default"
FORWARD_RTOL = 1e-4
FORWARD_ATOL = 1e-4
SAMPLE_RTOL = 1e-2
SAMPLE_ATOL = 5e-3
DUMMY_DATASET_STATS = {
"observation.state": {
OBS_STATE: {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
"q01": torch.zeros(DUMMY_STATE_DIM),
"q99": torch.ones(DUMMY_STATE_DIM),
},
"action": {
ACTION: {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
"q01": torch.zeros(DUMMY_ACTION_DIM),
@@ -87,6 +92,15 @@ DUMMY_DATASET_STATS = {
}
@pytest.fixture(autouse=True)
def cleanup_cuda_after_test():
yield
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
class PI0BaseOriginalConfig:
action_dim: int = DUMMY_ACTION_DIM
action_horizon: int = DUMMY_ACTION_HORIZON
@@ -95,333 +109,156 @@ class PI0BaseOriginalConfig:
precision: str = "float32"
pi05: bool = False
dtype: str = "float32"
pytorch_compile_mode: str | None = None
def instantiate_lerobot_pi0(
from_pretrained: bool = False,
) -> tuple[
PI0Policy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
if from_pretrained:
# Load the policy first
policy = PI0Policy.from_pretrained(pretrained_name_or_path="lerobot/pi0_base", strict=True)
else:
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
policy = PI0Policy(config)
def instantiate_lerobot_pi0(*, compile_model: bool = False, gradient_checkpointing: bool = False):
config = PreTrainedConfig.from_pretrained("lerobot/pi0_base")
config.device = str(DEVICE)
config.dtype = "float32"
config.compile_model = compile_model
config.compile_mode = COMPILE_MODE
config.gradient_checkpointing = gradient_checkpointing
policy = PI0Policy.from_pretrained("lerobot/pi0_base", config=config, strict=True)
policy.to(DEVICE)
policy.config.device = DEVICE
preprocessor, postprocessor = make_pi0_pre_post_processors(
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
)
return (policy, preprocessor, postprocessor)
policy.config.device = str(DEVICE)
preprocessor, _ = make_pi0_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
return policy, preprocessor
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None):
config = PI0BaseOriginalConfig()
policy = PI0Pytorch(config)
if from_pretrained:
try:
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi0_base)...")
# Download the model from HuggingFace Hub
import safetensors.torch
from huggingface_hub import snapshot_download
# Download the entire repository
if model_path and os.path.exists(model_path):
cache_dir = model_path
print(f"Using cached model from: {cache_dir}")
else:
cache_dir = snapshot_download(repo_id="lerobot/pi0_base", repo_type="model")
print(f"Downloaded model to: {cache_dir}")
# Try to load safetensors format first
model_file = os.path.join(cache_dir, "model.safetensors")
if os.path.exists(model_file):
state_dict = safetensors.torch.load_file(model_file)
print(f"Loaded {len(state_dict)} parameters from safetensors")
else:
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
# Load the state dict into the model
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
if missing_keys:
print(f"Missing keys: {len(missing_keys)}")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"Unexpected keys: {len(unexpected_keys)}")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All pretrained weights loaded successfully!")
else:
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
except Exception as e:
print(f"Failed to load pretrained weights: {e}")
print(" Using randomly initialized weights...")
import traceback
traceback.print_exc()
policy.to(DEVICE)
def instantiate_original_pi0():
policy = PI0Pytorch(PI0BaseOriginalConfig()).to(DEVICE)
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi0_base"))
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
assert missing_keys == []
assert unexpected_keys == []
return policy
def create_dummy_data():
batch_size = 2 # Reduce batch size for testing
device = DEVICE
# Use the exact same prompt for both implementations
batch_size = 2
prompt = "Pick up the red block and place it in the bin"
batch = {
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
"action": torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
return {
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
ACTION: torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
),
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
),
"observation.images.left_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
),
"observation.images.right_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
),
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
"task": [prompt for _ in range(batch_size)],
}
return batch
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
"""Extract the exact same processed inputs that LeRobot uses internally."""
# Get the tokenized language from LeRobot's internal method
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
# Get the preprocessed images from LeRobot's internal method
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
# Create dummy token_ar_mask and token_loss_mask for original implementation
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
def prepare_parity_inputs(lerobot_pi0, lerobot_preprocessor):
torch.manual_seed(0)
raw_batch = create_dummy_data()
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
openpi_observation = make_openpi_observation_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
max_token_len=DUMMY_MAX_TOKEN_LEN,
dataset_stats=DUMMY_DATASET_STATS,
pi05=False,
)
openpi_actions = openpi_model_actions_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
dataset_stats=DUMMY_DATASET_STATS,
pi05=False,
)
assert_processor_inputs_match_lerobot(
lerobot_pi0,
lerobot_batch,
openpi_observation,
compare_state=True,
)
batch_size = raw_batch[OBS_STATE].shape[0]
noise = torch.randn(
batch_size,
DUMMY_ACTION_HORIZON,
DUMMY_ACTION_DIM,
dtype=torch.float32,
device=DEVICE,
)
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
return lerobot_batch, openpi_observation, openpi_actions, noise, time
class PI0Observation:
"""Observation class that matches the original OpenPI format."""
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(
compile_model=compile_model,
gradient_checkpointing=gradient_checkpointing,
)
original_pi0 = instantiate_original_pi0()
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
lerobot_pi0,
lerobot_preprocessor,
)
def __init__(
self,
state,
images,
image_masks,
tokenized_prompt,
tokenized_prompt_mask,
token_ar_mask,
token_loss_mask,
):
self.state = state
self.images = images
self.image_masks = image_masks
self.tokenized_prompt = tokenized_prompt
self.tokenized_prompt_mask = tokenized_prompt_mask
self.token_ar_mask = token_ar_mask
self.token_loss_mask = token_loss_mask
def create_original_observation_with_openpi_preprocessing(batch):
"""Create observation object for OpenPI using OpenPI's own preprocessing."""
batch_size = batch["observation.state"].shape[0]
device = batch["observation.state"].device
# Create tokenizer for OpenPI (same as LeRobot uses)
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
# Get task description
if "task" in batch:
tasks = batch["task"]
if isinstance(tasks, str):
# Single string: add newline if not present, then convert to list
if not tasks.endswith("\n"):
tasks = f"{tasks}\n"
tasks = [tasks]
elif isinstance(tasks, list) and all(isinstance(t, str) for t in tasks):
# List of strings: add newline to each if not present
tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks]
if len(tasks) == 1:
# Expand to batch size
tasks = tasks * batch_size
if len(tasks) != batch_size:
raise ValueError(f"Expected batch size {batch_size}, got {len(tasks)}")
# If task is neither string nor list of strings, leave unchanged
if gradient_checkpointing:
lerobot_pi0.train()
else:
# Default task if not provided
tasks = ["Pick up the object\n"] * batch_size
# Tokenize with max_length padding to match OpenPI's expected format
tokenized = tokenizer(
tasks,
padding="max_length",
padding_side="right",
truncation=True,
max_length=DUMMY_MAX_TOKEN_LEN,
return_tensors="pt",
)
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
# Create dummy token_ar_mask and token_loss_mask for OpenPI
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
image_dict = {
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
}
# Create image masks (all ones for real images)
image_masks_dict = {}
for key in image_dict:
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
# Create raw observation object (before preprocessing)
raw_observation = PI0Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
# Now use OpenPI's preprocessing
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
return processed_obs
def create_original_observation_from_lerobot(lerobot_pi0, batch):
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
_batch_size = batch["observation.state"].shape[0]
_device = batch["observation.state"].device
# Extract the exact same processed inputs that LeRobot uses
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
extract_lerobot_processed_inputs(lerobot_pi0, batch)
)
# Convert images list to dict with original OpenPI keys
image_dict = {
"base_0_rgb": images[0],
"left_wrist_0_rgb": images[1],
"right_wrist_0_rgb": images[2],
}
# Convert image masks list to dict with original OpenPI keys
image_masks_dict = {
"base_0_rgb": img_masks[0],
"left_wrist_0_rgb": img_masks[1],
"right_wrist_0_rgb": img_masks[2],
}
return PI0Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
def test_pi0_original_vs_lerobot():
"""Test PI0 original implementation vs LeRobot implementation."""
print("Initializing models...")
lerobot_pi0, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi0(
from_pretrained=True
) # Load pretrained LeRobot model
original_pi0 = instantiate_original_pi0(
from_pretrained=True
) # Load pretrained OpenPI model from HuggingFace Hub
print("Creating dummy data...")
batch = create_dummy_data()
batch_lerobot = deepcopy(batch)
# Test each model with its own preprocessing (more realistic end-to-end test)
print("\nTest each model with its own preprocessing")
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
print(f"Task prompt: '{batch['task'][0]}'")
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
print("Testing OpenPI with own preprocessing...")
lerobot_pi0.eval()
original_pi0.eval()
torch.manual_seed(42) # Set seed for reproducibility
batch_size = batch["observation.state"].shape[0]
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
openpi_actions = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
)
openpi_actions_unit = openpi_actions[:, 0, :]
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
with fixed_flow_sampling(lerobot_pi0.model, noise=noise, time=time):
lerobot_loss, _ = lerobot_pi0(lerobot_batch, reduction="none")
with deterministic_openpi_forward_preprocess(original_pi0):
openpi_losses = original_pi0(openpi_observation, openpi_actions, noise=noise, time=time)
openpi_loss = openpi_losses.mean(dim=(1, 2))
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(compile_model=compile_model)
original_pi0 = instantiate_original_pi0()
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
lerobot_pi0,
lerobot_preprocessor,
)
print("Testing LeRobot with own preprocessing...")
lerobot_pi0.eval()
torch.manual_seed(42) # Set the same seed
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
original_pi0.eval()
with torch.no_grad():
lerobot_actions_own = lerobot_pi0.predict_action_chunk(
batch_lerobot_processed
) # batch_size, n_action_steps, action_dim
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
lerobot_actions = lerobot_pi0.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
openpi_actions = original_pi0.sample_actions(
device=DEVICE,
observation=openpi_observation,
noise=noise,
num_steps=10,
)
print("\nComparing end-to-end implementations:")
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
def test_pi0_forward_matches_openpi():
assert_forward_matches()
def test_pi0_sample_actions_match_openpi():
assert_sample_actions_match_openpi()
def test_pi0_gradient_checkpointing_forward_matches_openpi():
assert_forward_matches(gradient_checkpointing=True)
def test_pi0_compile_forward_matches_openpi():
assert_forward_matches(compile_model=True)
def test_pi0_compile_sample_actions_match_openpi():
assert_sample_actions_match_openpi(compile_model=True)
def test_pi0_compile_gradient_checkpointing_forward_matches_openpi():
assert_forward_matches(compile_model=True, gradient_checkpointing=True)

View File

@@ -0,0 +1 @@
"""Utilities shared by PI0/PI05 policy tests."""

View File

@@ -0,0 +1,291 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
import numpy as np
import safetensors.torch
import torch
import torch.nn.functional as F # noqa: N812
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as openpi_preprocessing
IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
TOKENIZER_NAME = "google/paligemma-3b-pt-224"
@dataclass
class OpenPIObservation:
state: torch.Tensor
images: dict[str, torch.Tensor]
image_masks: dict[str, torch.Tensor]
tokenized_prompt: torch.Tensor
tokenized_prompt_mask: torch.Tensor
token_ar_mask: torch.Tensor
token_loss_mask: torch.Tensor
@lru_cache(maxsize=1)
def paligemma_tokenizer():
return AutoTokenizer.from_pretrained(TOKENIZER_NAME)
def clone_batch(batch: dict) -> dict:
return {
key: value.clone() if isinstance(value, torch.Tensor) else list(value) for key, value in batch.items()
}
def pad_last_dim(tensor: torch.Tensor, target_dim: int) -> torch.Tensor:
if tensor.shape[-1] > target_dim:
raise ValueError(f"Cannot pad last dimension {tensor.shape[-1]} down to {target_dim}")
return F.pad(tensor, (0, target_dim - tensor.shape[-1]))
def mean_std_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
mean = stats["mean"].to(device=tensor.device, dtype=tensor.dtype)
std = stats["std"].to(device=tensor.device, dtype=tensor.dtype)
return (tensor - mean) / (std + 1e-8)
def quantile_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
q01 = stats["q01"].to(device=tensor.device, dtype=tensor.dtype)
q99 = stats["q99"].to(device=tensor.device, dtype=tensor.dtype)
denom = torch.where(q99 == q01, torch.full_like(q99, 1e-8), q99 - q01)
return 2.0 * (tensor - q01) / denom - 1.0
def openpi_model_state_from_raw(
batch: dict[str, torch.Tensor],
*,
action_dim: int,
dataset_stats: dict[str, dict[str, torch.Tensor]],
pi05: bool,
) -> torch.Tensor:
state = batch[OBS_STATE].to(dtype=torch.float32)
if pi05:
state = quantile_normalize(state, dataset_stats[OBS_STATE])
else:
state = mean_std_normalize(state, dataset_stats[OBS_STATE])
return pad_last_dim(state, action_dim)
def openpi_model_actions_from_raw(
batch: dict[str, torch.Tensor],
*,
action_dim: int,
dataset_stats: dict[str, dict[str, torch.Tensor]],
pi05: bool,
) -> torch.Tensor:
actions = batch[ACTION].to(dtype=torch.float32)
if pi05:
actions = quantile_normalize(actions, dataset_stats[ACTION])
else:
actions = mean_std_normalize(actions, dataset_stats[ACTION])
return pad_last_dim(actions, action_dim)
def _tasks_from_raw(batch: dict, batch_size: int) -> list[str]:
tasks = batch.get("task")
if tasks is None:
raise ValueError("The parity batch must include a task prompt.")
if isinstance(tasks, str):
return [tasks] * batch_size
if len(tasks) == 1:
return [tasks[0]] * batch_size
if len(tasks) != batch_size:
raise ValueError(f"Expected {batch_size} task prompts, got {len(tasks)}")
return list(tasks)
def _format_pi0_prompts(tasks: list[str]) -> list[str]:
return [f"{task.strip().replace('_', ' ').replace(chr(10), ' ')}\n" for task in tasks]
def _format_pi05_prompts(tasks: list[str], normalized_state: torch.Tensor) -> list[str]:
state_np = normalized_state.detach().cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
prompts = []
for task, state in zip(tasks, discretized_states, strict=True):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, state))
prompts.append(f"Task: {cleaned_text}, State: {state_str};\nAction: ")
return prompts
def _tokenize_prompts(prompts: list[str], *, max_token_len: int, device: torch.device | str):
tokenized = paligemma_tokenizer()(
prompts,
padding="max_length",
padding_side="right",
truncation=True,
max_length=max_token_len,
return_tensors="pt",
)
tokens = tokenized["input_ids"].to(device)
masks = tokenized["attention_mask"].to(device=device, dtype=torch.bool)
return tokens, masks
def make_openpi_observation_from_raw(
batch: dict[str, torch.Tensor],
*,
action_dim: int,
max_token_len: int,
dataset_stats: dict[str, dict[str, torch.Tensor]],
pi05: bool,
) -> OpenPIObservation:
batch_size = batch[OBS_STATE].shape[0]
device = batch[OBS_STATE].device
state = openpi_model_state_from_raw(
batch,
action_dim=action_dim,
dataset_stats=dataset_stats,
pi05=pi05,
)
tasks = _tasks_from_raw(batch, batch_size)
prompts = _format_pi05_prompts(tasks, state) if pi05 else _format_pi0_prompts(tasks)
tokens, masks = _tokenize_prompts(prompts, max_token_len=max_token_len, device=device)
images = {
key: batch[f"observation.images.{key}"].to(device=device, dtype=torch.float32) * 2.0 - 1.0
for key in IMAGE_KEYS
}
image_masks = {key: torch.ones(batch_size, dtype=torch.bool, device=device) for key in IMAGE_KEYS}
return OpenPIObservation(
state=state,
images=images,
image_masks=image_masks,
tokenized_prompt=tokens,
tokenized_prompt_mask=masks,
token_ar_mask=torch.zeros_like(tokens, dtype=torch.int32),
token_loss_mask=torch.ones_like(masks, dtype=torch.bool),
)
def assert_processor_inputs_match_lerobot(
lerobot_policy,
lerobot_batch: dict[str, torch.Tensor],
openpi_observation: OpenPIObservation,
*,
compare_state: bool,
):
openpi_processed = openpi_preprocessing.preprocess_observation_pytorch(openpi_observation, train=False)
lerobot_images, lerobot_image_masks = lerobot_policy._preprocess_images(lerobot_batch)
# Token IDs, token masks, images, image masks, and PI0 state are intentionally built from the same
# raw batch through independent LeRobot/OpenPI-style processor logic. They must be bitwise equal.
torch.testing.assert_close(
openpi_observation.tokenized_prompt, lerobot_batch[OBS_LANGUAGE_TOKENS], rtol=0, atol=0
)
torch.testing.assert_close(
openpi_observation.tokenized_prompt_mask,
lerobot_batch[OBS_LANGUAGE_ATTENTION_MASK],
rtol=0,
atol=0,
)
for openpi_image, lerobot_image in zip(openpi_processed.images.values(), lerobot_images, strict=True):
torch.testing.assert_close(openpi_image, lerobot_image, rtol=0, atol=0)
for openpi_mask, lerobot_mask in zip(
openpi_processed.image_masks.values(), lerobot_image_masks, strict=True
):
torch.testing.assert_close(openpi_mask, lerobot_mask, rtol=0, atol=0)
if compare_state:
torch.testing.assert_close(
openpi_processed.state, lerobot_policy.prepare_state(lerobot_batch), rtol=0, atol=0
)
def load_openpi_reference_state_dict(repo_id: str) -> dict[str, torch.Tensor]:
cache_dir = Path(snapshot_download(repo_id=repo_id, repo_type="model"))
return safetensors.torch.load_file(cache_dir / "model.safetensors")
def fix_reference_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
fixed_state_dict = dict(state_dict)
lm_head_key = "paligemma_with_expert.paligemma.lm_head.weight"
embed_tokens_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
if lm_head_key in fixed_state_dict and embed_tokens_key not in fixed_state_dict:
fixed_state_dict[embed_tokens_key] = fixed_state_dict[lm_head_key].clone()
return fixed_state_dict
@contextmanager
def fixed_flow_sampling(model, *, noise: torch.Tensor, time: torch.Tensor) -> Iterator[None]:
original_sample_noise = model.sample_noise
original_sample_time = model.sample_time
def sample_noise(shape, device):
if tuple(shape) != tuple(noise.shape):
raise ValueError(f"Expected noise shape {tuple(noise.shape)}, got {tuple(shape)}")
return noise.to(device=device)
def sample_time(batch_size, device):
if batch_size != time.shape[0]:
raise ValueError(f"Expected time batch size {time.shape[0]}, got {batch_size}")
return time.to(device=device)
model.sample_noise = sample_noise
model.sample_time = sample_time
try:
yield
finally:
model.sample_noise = original_sample_noise
model.sample_time = original_sample_time
@contextmanager
def deterministic_openpi_forward_preprocess(openpi_policy) -> Iterator[None]:
"""Disable OpenPI's training-time image augmentation only inside a parity forward block.
OpenPI's `forward()` calls `_preprocess_observation(..., train=True)`, which can apply stochastic
image augmentation. LeRobot's policy forward path does not apply that augmentation, so parity would
otherwise compare two different image tensors rather than two model implementations. The context manager
keeps the public `openpi_policy.forward(observation, ...)` call while making preprocessing deterministic.
`yield` marks the body of the caller's `with` block. The `try/finally` restores the original method even
if the assertion inside the block fails, so the temporary monkeypatch cannot leak into later tests.
"""
original_preprocess_observation = openpi_policy._preprocess_observation
def preprocess_observation(observation, *, train=True):
return original_preprocess_observation(observation, train=False)
openpi_policy._preprocess_observation = preprocess_observation
try:
yield
finally:
openpi_policy._preprocess_observation = original_preprocess_observation

View File

@@ -0,0 +1,207 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from collections.abc import Callable
import torch
from torch._dynamo.utils import counters, guard_failures
from torch.profiler import ProfilerActivity
FORWARD_RTOL = 1e-5
FORWARD_ATOL = 5e-2
SAMPLE_RTOL = 1e-5
SAMPLE_ATOL = 1e-2
COMPILE_MODE = "max-autotune"
STEADY_STATE_WARMUPS = 3
STEADY_STATE_REPEATS = 3
def make_compile_config(config_cls, *, compile_model):
return config_cls(device="cuda", compile_model=compile_model, compile_mode=COMPILE_MODE)
def counter_total(name):
return sum(counters.get(name, {}).values())
def compile_snapshot():
return {
"graph_breaks": counter_total("graph_break"),
"recompiles": counter_total("recompiles"),
"recompile_limits": counter_total("recompile_limit"),
"unique_graphs": counters["stats"].get("unique_graphs", 0),
}
def reset_compile_state():
torch._dynamo.reset()
counters.clear()
guard_failures.clear()
def clone_cuda_graph_output(output):
if torch.is_tensor(output):
return output.clone()
if isinstance(output, tuple):
return tuple(clone_cuda_graph_output(item) for item in output)
if isinstance(output, list):
return [clone_cuda_graph_output(item) for item in output]
if isinstance(output, dict):
return {key: clone_cuda_graph_output(value) for key, value in output.items()}
return output
def run_model_step(fn: Callable, kwargs: dict):
if hasattr(torch.compiler, "cudagraph_mark_step_begin"):
torch.compiler.cudagraph_mark_step_begin()
return fn(**kwargs)
def assert_explain_has_no_graph_breaks(fn: Callable, kwargs: dict, label: str):
reset_compile_state()
explanation = torch._dynamo.explain(fn)(**kwargs)
assert explanation.graph_count > 0, f"{label} was not captured by Dynamo"
assert explanation.graph_break_count == 0, (
f"{label} has {explanation.graph_break_count} graph break(s): {explanation.break_reasons}"
)
assert not explanation.break_reasons, f"{label} graph break reasons: {explanation.break_reasons}"
print(
f"{label} capture: graphs={explanation.graph_count}, "
f"graph_breaks={explanation.graph_break_count}, ops={explanation.op_count}, "
f"guards={len(explanation.out_guards or [])}"
)
return explanation
@torch.no_grad()
def assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs):
eager_forward = eager_model.forward(**forward_kwargs)
compiled_forward = compiled_model.forward(**forward_kwargs)
torch.testing.assert_close(compiled_forward, eager_forward, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
eager_actions = eager_model.sample_actions(**sample_kwargs)
compiled_actions = compiled_model.sample_actions(**sample_kwargs)
torch.testing.assert_close(compiled_actions, eager_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
@torch.no_grad()
def assert_cache_stability(fn: Callable, kwargs: dict, label: str):
reset_compile_state()
first_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
first_snapshot = compile_snapshot()
second_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
second_snapshot = compile_snapshot()
third_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
third_snapshot = compile_snapshot()
torch.testing.assert_close(second_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
torch.testing.assert_close(third_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
assert first_snapshot["unique_graphs"] > 0, f"{label} did not compile any graph"
assert third_snapshot["graph_breaks"] == 0, f"{label} graph breaks: {third_snapshot}"
assert third_snapshot["recompiles"] == 0, f"{label} recompiled: {third_snapshot}"
assert third_snapshot["recompile_limits"] == 0, f"{label} hit recompile limit: {third_snapshot}"
assert second_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
f"{label} compiled new graph on second call: first={first_snapshot}, second={second_snapshot}"
)
assert third_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
f"{label} compiled new graph on third call: first={first_snapshot}, third={third_snapshot}"
)
assert not guard_failures, f"{label} guard failures: {dict(guard_failures)}"
print(f"{label} cache: first={first_snapshot}, third={third_snapshot}")
@torch.no_grad()
def benchmark_runtime(eager_fn: Callable, compiled_fn: Callable, kwargs: dict, label: str):
run_warmups(eager_fn, kwargs)
run_warmups(compiled_fn, kwargs)
torch.cuda.synchronize()
eager_metrics = profile_callable(eager_fn, kwargs)
compiled_metrics = profile_callable(compiled_fn, kwargs)
speedup = eager_metrics["cuda_event_ms"] / compiled_metrics["cuda_event_ms"]
print(
f"{label} runtime: eager_cuda={eager_metrics['cuda_event_ms']:.3f} ms, "
f"compiled_cuda={compiled_metrics['cuda_event_ms']:.3f} ms, speedup={speedup:.3f}x, "
f"host_wall_ms eager/compiled={eager_metrics['host_wall_ms']:.3f}/"
f"{compiled_metrics['host_wall_ms']:.3f}, "
f"cpu_self_time_ms eager/compiled={eager_metrics['cpu_self_time_ms']:.3f}/"
f"{compiled_metrics['cpu_self_time_ms']:.3f}, "
f"cuda_launches eager/compiled={eager_metrics['cuda_launch_count']}/"
f"{compiled_metrics['cuda_launch_count']}, "
f"profiler_events eager/compiled={eager_metrics['profiler_event_count']}/"
f"{compiled_metrics['profiler_event_count']}, "
f"peak_mem_mib eager/compiled={eager_metrics['peak_mem_mib']:.1f}/"
f"{compiled_metrics['peak_mem_mib']:.1f}"
)
assert eager_metrics["cuda_event_ms"] > 0
assert compiled_metrics["cuda_event_ms"] > 0
assert eager_metrics["profiler_event_count"] > 0
assert compiled_metrics["profiler_event_count"] > 0
return eager_metrics, compiled_metrics
def run_warmups(fn: Callable, kwargs: dict):
for _ in range(STEADY_STATE_WARMUPS):
run_model_step(fn, kwargs)
torch.cuda.synchronize()
def profile_callable(fn: Callable, kwargs: dict):
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
host_start = time.perf_counter()
start_event.record()
for _ in range(STEADY_STATE_REPEATS):
run_model_step(fn, kwargs)
end_event.record()
torch.cuda.synchronize()
cuda_event_ms = start_event.elapsed_time(end_event) / STEADY_STATE_REPEATS
host_wall_ms = (time.perf_counter() - host_start) * 1000 / STEADY_STATE_REPEATS
peak_mem_mib = torch.cuda.max_memory_allocated() / 1024**2
with torch.profiler.profile(
activities=[ProfilerActivity.CPU],
) as profiler:
run_model_step(fn, kwargs)
torch.cuda.synchronize()
key_averages = profiler.key_averages()
cpu_self_time_ms = sum(event.self_cpu_time_total for event in key_averages) / 1000
cuda_launch_count = sum(
event.count
for event in key_averages
if event.key in {"cudaLaunchKernel", "cudaGraphLaunch", "cudaLaunchKernelExC"}
)
profiler_event_count = sum(event.count for event in key_averages)
return {
"cuda_event_ms": cuda_event_ms,
"host_wall_ms": host_wall_ms,
"cpu_self_time_ms": cpu_self_time_ms,
"cuda_launch_count": cuda_launch_count,
"profiler_event_count": profiler_event_count,
"peak_mem_mib": peak_mem_mib,
}

View File

@@ -169,6 +169,7 @@ class _FakeQwenBackbone(nn.Module):
dtype=torch.float32,
).view(batch_size, seq_len, hidden_size)
hidden = values / values.numel() + self.weight
self.model(input_ids) # call through so the forward hook on layers[-1] fires
return SimpleNamespace(hidden_states=[hidden])
@@ -241,9 +242,13 @@ class _FakeVideoEncoder(nn.Module):
class _FakeVideoProcessor:
def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
def __call__(self, videos, return_tensors: str) -> dict[str, Tensor]:
assert return_tensors == "pt"
return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)}
if isinstance(videos, list):
pixel_values = torch.stack([torch.as_tensor(v) for v in videos])
else:
pixel_values = torch.as_tensor(videos).unsqueeze(0)
return {"pixel_values_videos": pixel_values}
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,155 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare the PI0.5 processor pipeline against the vendored OpenPI reference processors."""
import os
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
from lerobot.policies.pi05 import PI05Policy # noqa: E402
from lerobot.policies.pi05.configuration_pi05 import PI05Config # noqa: E402
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
IMAGE_KEYS,
assert_processor_inputs_match_lerobot,
clone_batch,
make_openpi_observation_from_raw,
openpi_model_actions_from_raw,
)
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
)
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 200
DEVICE = torch.device("cpu")
DUMMY_DATASET_STATS = {
OBS_STATE: {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
"q01": torch.zeros(DUMMY_STATE_DIM),
"q99": torch.ones(DUMMY_STATE_DIM),
},
ACTION: {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
"q01": torch.zeros(DUMMY_ACTION_DIM),
"q99": torch.ones(DUMMY_ACTION_DIM),
},
"images": {
key: {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
"q01": torch.zeros(3, 224, 224),
"q99": torch.ones(3, 224, 224),
}
for key in IMAGE_KEYS
},
}
class PI05PolicyInputAdapter(torch.nn.Module):
"""Minimal adapter exposing PI0.5 policy image preparation without loading model weights."""
_preprocess_images = PI05Policy._preprocess_images
def __init__(self, config: PI05Config) -> None:
super().__init__()
self.config = config
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
def create_pi05_config() -> PI05Config:
config = PI05Config(device=str(DEVICE))
config.max_state_dim = DUMMY_STATE_DIM
config.max_action_dim = DUMMY_ACTION_DIM
config.chunk_size = DUMMY_ACTION_HORIZON
config.n_action_steps = DUMMY_ACTION_HORIZON
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
config.input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
**{
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
for key in IMAGE_KEYS
},
}
config.output_features = {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
}
return config
def create_dummy_data() -> dict:
batch_size = 2
prompt = "Pick up the red block and place it in the bin"
return {
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
ACTION: torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
),
**{
f"observation.images.{key}": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
)
for key in IMAGE_KEYS
},
"task": [prompt for _ in range(batch_size)],
}
def test_pi05_processor_inputs_match_openpi_reference():
torch.manual_seed(0)
config = create_pi05_config()
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
raw_batch = create_dummy_data()
lerobot_batch = preprocessor(clone_batch(raw_batch))
openpi_observation = make_openpi_observation_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
max_token_len=DUMMY_MAX_TOKEN_LEN,
dataset_stats=DUMMY_DATASET_STATS,
pi05=True,
)
assert_processor_inputs_match_lerobot(
PI05PolicyInputAdapter(config),
lerobot_batch,
openpi_observation,
compare_state=False,
)
torch.testing.assert_close(
lerobot_batch[ACTION],
openpi_model_actions_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
dataset_stats=DUMMY_DATASET_STATS,
pi05=True,
),
rtol=0,
atol=0,
)

View File

@@ -0,0 +1,156 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare the PI0 processor pipeline against the vendored OpenPI reference processors."""
import os
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
from lerobot.policies.pi0 import PI0Policy # noqa: E402
from lerobot.policies.pi0.configuration_pi0 import PI0Config # noqa: E402
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
IMAGE_KEYS,
assert_processor_inputs_match_lerobot,
clone_batch,
make_openpi_observation_from_raw,
openpi_model_actions_from_raw,
)
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
)
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 48
DEVICE = torch.device("cpu")
DUMMY_DATASET_STATS = {
OBS_STATE: {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
"q01": torch.zeros(DUMMY_STATE_DIM),
"q99": torch.ones(DUMMY_STATE_DIM),
},
ACTION: {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
"q01": torch.zeros(DUMMY_ACTION_DIM),
"q99": torch.ones(DUMMY_ACTION_DIM),
},
"images": {
key: {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
"q01": torch.zeros(3, 224, 224),
"q99": torch.ones(3, 224, 224),
}
for key in IMAGE_KEYS
},
}
class PI0PolicyInputAdapter(torch.nn.Module):
"""Minimal adapter exposing PI0 policy input-preparation helpers without loading model weights."""
_preprocess_images = PI0Policy._preprocess_images
prepare_state = PI0Policy.prepare_state
def __init__(self, config: PI0Config) -> None:
super().__init__()
self.config = config
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
def create_pi0_config() -> PI0Config:
config = PI0Config(device=str(DEVICE))
config.max_state_dim = DUMMY_STATE_DIM
config.max_action_dim = DUMMY_ACTION_DIM
config.chunk_size = DUMMY_ACTION_HORIZON
config.n_action_steps = DUMMY_ACTION_HORIZON
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
config.input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
**{
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
for key in IMAGE_KEYS
},
}
config.output_features = {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
}
return config
def create_dummy_data() -> dict:
batch_size = 2
prompt = "Pick up the red block and place it in the bin"
return {
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
ACTION: torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
),
**{
f"observation.images.{key}": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
)
for key in IMAGE_KEYS
},
"task": [prompt for _ in range(batch_size)],
}
def test_pi0_processor_inputs_match_openpi_reference():
torch.manual_seed(0)
config = create_pi0_config()
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
raw_batch = create_dummy_data()
lerobot_batch = preprocessor(clone_batch(raw_batch))
openpi_observation = make_openpi_observation_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
max_token_len=DUMMY_MAX_TOKEN_LEN,
dataset_stats=DUMMY_DATASET_STATS,
pi05=False,
)
assert_processor_inputs_match_lerobot(
PI0PolicyInputAdapter(config),
lerobot_batch,
openpi_observation,
compare_state=True,
)
torch.testing.assert_close(
lerobot_batch[ACTION],
openpi_model_actions_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
dataset_stats=DUMMY_DATASET_STATS,
pi05=False,
),
rtol=0,
atol=0,
)

View File

@@ -1,10 +1,14 @@
"""Tests for policy.path support in YAML config files (issue #2957)."""
import json
import sys
import tempfile
from dataclasses import dataclass, field
from unittest.mock import patch
import yaml
from lerobot.configs import parser
from lerobot.configs.parser import (
_config_path_args,
_config_yaml_overrides,
@@ -16,7 +20,8 @@ from lerobot.configs.parser import (
def test_extract_path_fields_from_yaml():
"""Test that policy.path is extracted from a YAML config and removed."""
"""Test that policy.path is extracted from a YAML config and the policy block
is removed entirely (siblings are captured separately as cli_overrides)."""
config = {
"dataset": {"repo_id": "lerobot/pusht"},
"policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False},
@@ -26,26 +31,33 @@ def test_extract_path_fields_from_yaml():
config_path = f.name
_config_path_args.clear()
_config_yaml_overrides.clear()
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
# Path should be extracted and stored
assert _config_path_args["policy"] == "lerobot/smolvla_base"
# Cleaned config should not have the path field
# Cleaned config should not have the policy block at all -- draccus must not
# try to decode it as PreTrainedConfig; the actual config comes from
# from_pretrained(path) with the captured overrides applied on top.
with open(cleaned_path) as f:
cleaned = yaml.safe_load(f)
assert "path" not in cleaned["policy"]
assert cleaned["policy"]["type"] == "smolvla"
assert cleaned["policy"]["push_to_hub"] is False
assert "policy" not in cleaned
# Original dataset should be untouched
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
# Sibling overrides (excluding type/path) captured for from_pretrained.
overrides = get_yaml_overrides("policy")
assert any("push_to_hub=false" in o for o in overrides)
_config_path_args.clear()
_config_yaml_overrides.clear()
def test_extract_path_fields_from_json():
"""Test that policy.path is extracted from a JSON config."""
"""Test that policy.path is extracted from a JSON config and the policy
block is removed entirely."""
config = {
"policy": {"type": "act", "path": "some/local/path"},
}
@@ -54,15 +66,17 @@ def test_extract_path_fields_from_json():
config_path = f.name
_config_path_args.clear()
_config_yaml_overrides.clear()
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
assert _config_path_args["policy"] == "some/local/path"
with open(cleaned_path) as f:
cleaned = json.load(f)
assert "path" not in cleaned["policy"]
assert "policy" not in cleaned
_config_path_args.clear()
_config_yaml_overrides.clear()
def test_extract_no_path_returns_original():
@@ -216,3 +230,91 @@ def test_flatten_nested_with_bools():
args = _flatten_to_cli_args(d)
assert "--optimizer.use_warmup=true" in args
assert "--optimizer.lr=0.01" in args
def test_extract_removes_field_with_siblings_and_no_type():
"""Regression: when policy.path has siblings but no type:, the entire policy
block must still be removed from the cleaned config. Otherwise draccus tries
to decode the leftover dict as PreTrainedConfig and crashes on the missing
type discriminator.
"""
config = {
"dataset": {"repo_id": "lerobot/pusht"},
"policy": {
"path": "lerobot/smolvla_base",
"n_action_steps": 10,
"dtype": "bfloat16",
},
}
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml.dump(config, f)
config_path = f.name
_config_path_args.clear()
_config_yaml_overrides.clear()
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
with open(cleaned_path) as f:
cleaned = yaml.safe_load(f) or {}
assert "policy" not in cleaned, "policy block should be fully removed when path is present"
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
assert _config_path_args["policy"] == "lerobot/smolvla_base"
overrides = get_yaml_overrides("policy")
assert any("n_action_steps=10" in o for o in overrides)
assert any("dtype=bfloat16" in o for o in overrides)
_config_path_args.clear()
_config_yaml_overrides.clear()
@dataclass
class _DummyNested:
foo: int = 0
@dataclass
class _DummyConfig:
nested: _DummyNested = field(default_factory=_DummyNested)
other: str = "default"
@classmethod
def __get_path_fields__(cls):
return ["nested"]
def test_wrap_uses_cleaned_config_for_draccus_parse():
"""Regression: wrap() updates config_path_cli to point at the cleaned temp
file but must propagate that to the draccus.parse fallback branch. Without
the fix, cli_args still contains --config_path=<original> and draccus reads
the original YAML with `path:` still in it, crashing on the unknown field.
"""
config = {
"nested": {"path": "some/checkpoint", "foo": 42},
"other": "set-via-yaml",
}
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml.dump(config, f)
config_path = f.name
_config_path_args.clear()
_config_yaml_overrides.clear()
captured: dict = {}
@parser.wrap()
def main(cfg: _DummyConfig) -> _DummyConfig:
captured["cfg"] = cfg
return cfg
with patch.object(sys, "argv", ["prog", f"--config_path={config_path}"]):
main()
assert captured["cfg"].other == "set-via-yaml"
assert _config_path_args["nested"] == "some/checkpoint"
# Cleaned config dropped `nested:` entirely; defaults stand for this wrapper
# class (a real PreTrainedConfig would now load the checkpoint and apply
# the captured yaml_overrides via from_pretrained()).
assert captured["cfg"].nested.foo == 0
_config_path_args.clear()
_config_yaml_overrides.clear()

22
uv.lock generated
View File

@@ -3212,7 +3212,7 @@ requires-dist = [
{ name = "pandas", marker = "extra == 'video-benchmark'", specifier = ">=2.2.2,<2.4.0" },
{ name = "peft", marker = "extra == 'peft-dep'", specifier = ">=0.18.0,<1.0.0" },
{ name = "pillow", specifier = ">=10.0.0,<13.0.0" },
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.17" },
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.16" },
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" },
{ name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" },
{ name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" },
@@ -4601,7 +4601,7 @@ wheels = [
[[package]]
name = "placo"
version = "0.9.16"
version = "0.9.15"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cmeel" },
@@ -4611,16 +4611,16 @@ dependencies = [
{ name = "pin" },
{ name = "rhoban-cmeel-jsoncpp" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9e/0a/36c5b729d0d69075e7dfafd1b36c4df6fbb8c1ff1585e88d3c56d4c15010/placo-0.9.16.tar.gz", hash = "sha256:5314faaf6442e7ffe17347680d236af953951813bbfb1c09c4a27f7388d332e4", size = 136871, upload-time = "2025-11-07T14:24:58.811Z" }
sdist = { url = "https://files.pythonhosted.org/packages/40/c4/a33a0ee2ad798471a1c43a96109d28f358fd95c78a56f8cff57acb66d2bc/placo-0.9.15.tar.gz", hash = "sha256:df47f1154bae305c943bd20ba4f56d50ffc65625efc98679fefb11e8ff3c462c", size = 136856, upload-time = "2025-11-03T10:49:13.151Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/95/8a85b58033303fd354a680e1494f47801abdca9133c222ae1c2473983f25/placo-0.9.16-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:417a89920b340e3aec19f1f49e1fb06789c679a807450157af8bdf4aef4bc82b", size = 1641806, upload-time = "2025-11-07T14:24:34.736Z" },
{ url = "https://files.pythonhosted.org/packages/92/bd/2fb3556c71b0689b3168c0e85fce5befb605affcfe4afb3b5e7b5ba6749f/placo-0.9.16-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a7ef7ac33ba889d2122db0d7ed55eeecdffed020e2282712989bb11e408bab40", size = 1515468, upload-time = "2025-11-07T14:24:36.587Z" },
{ url = "https://files.pythonhosted.org/packages/ea/fd/7dba380720dfb89df582a51d0b2cb43957a36849f676baa3dfc74704e67f/placo-0.9.16-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:885773fe8a8e809022451ec16d47479562a042596f663b8c5bbe762cd616f573", size = 2106540, upload-time = "2025-11-07T14:24:38.149Z" },
{ url = "https://files.pythonhosted.org/packages/7a/40/97c7c799fe4f89111b973d7a5f86626a2ec1d0e6e20ce2988e0a2bda66f5/placo-0.9.16-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:19f097305c714e539fbf19e761897f6daab2ff73f639319431b144e77dd3852e", size = 2178511, upload-time = "2025-11-07T14:24:40.04Z" },
{ url = "https://files.pythonhosted.org/packages/f7/4d/f1700aae269584477b5d72561d2fc5ace37b4bca167892a74a369849c67e/placo-0.9.16-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:be11fa987702114097ccf3d94e1c4a891796878429e25c8d88b187ecc652e7ae", size = 1641812, upload-time = "2025-11-07T14:24:41.308Z" },
{ url = "https://files.pythonhosted.org/packages/43/d7/21d1d0dd1311c0cbd9ccd233cdae520bbe2370095e3c831059d6077c90bd/placo-0.9.16-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c2d65aeb4844eae28006ad3a50c8519b27c701912cc99c46c95e33ed049f3635", size = 1515457, upload-time = "2025-11-07T14:24:42.758Z" },
{ url = "https://files.pythonhosted.org/packages/0f/e8/939ba23bfa539fb90ab9ab1c2c59ff9a9a46e24699fc90e8ca3ff2948646/placo-0.9.16-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a7633aff1c592c1f45e86a174a372d5d7972673935cb9151391277ff49ec2072", size = 2106538, upload-time = "2025-11-07T14:24:44.517Z" },
{ url = "https://files.pythonhosted.org/packages/08/00/ad24cc0ad85fbe12267df28c2061e1eaef8f852146c467fcd7a681e11028/placo-0.9.16-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0d97a7284b65fc45aef27865c80cf7e53f04646d35bb18494ab62dfbbc9a35bd", size = 2178514, upload-time = "2025-11-07T14:24:45.994Z" },
{ url = "https://files.pythonhosted.org/packages/ef/03/207b1c087996b918fdbaa5a3a685e3b14b068cd303bf87affdf83f722b33/placo-0.9.15-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:eab7a299e73291fe631c02448b9e9826539f4824e198bcf85f7c91fdd77d054b", size = 1641975, upload-time = "2025-11-03T10:48:48.887Z" },
{ url = "https://files.pythonhosted.org/packages/92/55/40432b26bb1c5b9e677fbc41e8d85b54fa8897b7daebb2a22d410b0a7f7b/placo-0.9.15-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:23f9dd19b8d15fa9d86968948b57981ebc6f1decafeffc2d646d8b56f685b50d", size = 1515448, upload-time = "2025-11-03T10:48:50.562Z" },
{ url = "https://files.pythonhosted.org/packages/fd/8e/e6283201d329409dccf2045b5c1efd73b3dad5268143bbea4668029ca9c6/placo-0.9.15-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:2680a2166c23a0a2aa6226ad75c63a2b2310c812673a5db296616d9af053e076", size = 2106550, upload-time = "2025-11-03T10:48:52.364Z" },
{ url = "https://files.pythonhosted.org/packages/51/c3/77efe4c999e1d80ec14879ef73ea2a2144aa12db2b67870a562f87ed5b43/placo-0.9.15-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:1a2202a78bcd2874ca09a9a6526a95b38874803923cb9b3b4b96cd68ab4b7217", size = 2178531, upload-time = "2025-11-03T10:48:53.932Z" },
{ url = "https://files.pythonhosted.org/packages/fe/e7/b5cc5ad53414ff7af3357e0c9d97d902a3ce276e7810f8814fe9f0c1fb70/placo-0.9.15-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:84a445a99b059a512d1b4c64841a91d6f50149c7be9255c65bedeebbe6663989", size = 1641982, upload-time = "2025-11-03T10:48:55.277Z" },
{ url = "https://files.pythonhosted.org/packages/ad/1c/1c9163d941698a077617f218041efc573d3bf5a1c169a284112bd622fccd/placo-0.9.15-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b3106e7e6b05cbfa494239d8aa14795f7da8ee5dec851602f0d6297e311d7334", size = 1515447, upload-time = "2025-11-03T10:48:56.975Z" },
{ url = "https://files.pythonhosted.org/packages/cd/22/3d9b9045b89248c8476dd42243bc9821a123d9199e4e96a944124ad80cf1/placo-0.9.15-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:66c3d099e87551401aace04f1293a3c3563b1399319976647846845bf92c3ccf", size = 2106558, upload-time = "2025-11-03T10:48:58.667Z" },
{ url = "https://files.pythonhosted.org/packages/20/0b/45dbdd2c378a7cece578b7344fda493d5a2aa6777089798a315ce4f97c22/placo-0.9.15-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0e06b7d3d618ddc2b649ab8b0b46db8001fe72fe2fbcc801524df0ccc8a3da40", size = 2178531, upload-time = "2025-11-03T10:49:00.533Z" },
]
[[package]]