mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
Compare commits
9 Commits
docs/model
...
docs/add-l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
71848ebb2e | ||
|
|
5c98e80430 | ||
|
|
f65f3f7a4a | ||
|
|
8194897994 | ||
|
|
9f437d86b6 | ||
|
|
b74a551d38 | ||
|
|
c0a2e9814d | ||
|
|
bac4f61eae | ||
|
|
f4b834844e |
@@ -9,6 +9,8 @@
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: lelab
|
||||
title: LeLab - Lerobot GUI
|
||||
- local: bring_your_own_policies
|
||||
title: Adding a Policy
|
||||
- local: integrate_hardware
|
||||
|
||||
@@ -79,17 +79,13 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
|
||||
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/act_policy \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_robot \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Your task description" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
--task="Your task description" \ # can be skipped for ACT
|
||||
--duration=60
|
||||
```
|
||||
|
||||
@@ -105,10 +105,12 @@ These results demonstrate GR00T's strong generalization capabilities across dive
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout\
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--robot.type=bi_so_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
@@ -119,14 +121,12 @@ lerobot-record \
|
||||
}' \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10
|
||||
--duration=600
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760431541",
|
||||
id="my_red_robot_arm",
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
)
|
||||
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=my_awesome_leader_arm \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem5AB90687491 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem5AB90689011 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
</hfoption>
|
||||
@@ -122,34 +122,48 @@ lerobot-teleoperate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
import time
|
||||
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
|
||||
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
|
||||
|
||||
camera_config = {
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
|
||||
}
|
||||
|
||||
robot_config = KochFollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
id="my_red_robot_arm",
|
||||
cameras=camera_config
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
teleop_config = KochLeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = KochFollower(robot_config)
|
||||
teleop_device = KochLeader(teleop_config)
|
||||
init_rerun(session_name="teleoperation")
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop_device = SO101Leader(teleop_config)
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
TARGET_HZ = 30
|
||||
TIME_PER_FRAME = 1.0 / TARGET_HZ
|
||||
|
||||
while True:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
observation = robot.get_observation()
|
||||
action = teleop_device.get_action()
|
||||
robot.send_action(action)
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
|
||||
elapsed_time = time.perf_counter() - start_time
|
||||
sleep_time = TIME_PER_FRAME - elapsed_time
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -202,10 +216,11 @@ lerobot-record \
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
|
||||
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
@@ -218,71 +233,56 @@ EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
|
||||
# Create robot configuration
|
||||
robot_config = SO100FollowerConfig(
|
||||
id="my_awesome_follower_arm",
|
||||
cameras={
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
||||
},
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
)
|
||||
|
||||
teleop_config = SO100LeaderConfig(
|
||||
id="my_awesome_leader_arm",
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop = SO100Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
def main():
|
||||
# Create robot configuration
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop = SO101Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
@@ -291,26 +291,50 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# finalize dataset
|
||||
log_say("Finalizing dataset...")
|
||||
dataset.finalize()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -348,7 +372,7 @@ The `record` function provides a suite of tools for capturing and managing data
|
||||
##### 2. Checkpointing and Resuming
|
||||
|
||||
- Checkpoints are automatically created during recording.
|
||||
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
|
||||
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
|
||||
- To start recording from scratch, **manually delete** the dataset directory.
|
||||
|
||||
##### 3. Recording Parameters
|
||||
@@ -422,7 +446,7 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
robot.connect()
|
||||
@@ -490,6 +514,83 @@ Additionally you can provide extra `tags` or specify a `license` for your model
|
||||
|
||||
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
#### Train using Hugging Face Jobs
|
||||
|
||||
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
|
||||
|
||||
To run the training use this command:
|
||||
|
||||
<hfoptions id="train_with_hf_jobs">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
hf jobs run \
|
||||
--flavor a10g-small \
|
||||
--timeout 4h \
|
||||
--secrets HF_TOKEN \
|
||||
huggingface/lerobot-gpu:latest \
|
||||
-- \
|
||||
python -m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=username/dataset \
|
||||
--policy.type=act \
|
||||
--steps=5000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=username/your_policy \
|
||||
--log_freq=100
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from huggingface_hub import run_job, get_token
|
||||
|
||||
run_name = "act_so101_hf_jobs"
|
||||
dataset_id = "username/dataset"
|
||||
user_hub_id = "username"
|
||||
|
||||
command_args = [
|
||||
"python", "-m", "lerobot.scripts.lerobot_train",
|
||||
"--dataset.repo_id", dataset_id,
|
||||
"--policy.type", "act",
|
||||
"--steps", "5000",
|
||||
"--batch_size", "16",
|
||||
"--num_workers", "4",
|
||||
"--policy.device", "cuda",
|
||||
"--log_freq", "100",
|
||||
"--save_freq", "1000",
|
||||
"--save_checkpoint", "true",
|
||||
"--wandb.enable", "false",
|
||||
"--policy.repo_id", f"{user_hub_id}/{run_name}"
|
||||
]
|
||||
|
||||
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
|
||||
|
||||
job_info = run_job(
|
||||
image="huggingface/lerobot-gpu:latest",
|
||||
command=command_args,
|
||||
flavor="a10g-small",
|
||||
timeout="4h",
|
||||
secrets={"HF_TOKEN": get_token()}
|
||||
)
|
||||
|
||||
print("\n🚀 Job successfully launched!")
|
||||
print(f"🔹 Job ID: {job_info.id}")
|
||||
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
|
||||
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
|
||||
For longer training sessions increase the timeout.
|
||||
|
||||
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
|
||||
|
||||
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
42
docs/source/lelab.mdx
Normal file
42
docs/source/lelab.mdx
Normal file
@@ -0,0 +1,42 @@
|
||||
# LeLab - LeRobot Guide
|
||||
Graphical user interfaces are the easiest to use for beginners because it's easy to just click everything without remembering the proper commands. That's why we built LeLab which is a GUI built on top of the LeRobot library. With this app you will be able to add robots, collect datasets, train and deploy models.
|
||||
|
||||
### Installation
|
||||
To install lerobot you can simply copy the following command and paste into your terminal. For it to work you need to have `uv` installed, [here is how to do it.](https://docs.astral.sh/uv/getting-started/installation/)
|
||||
```
|
||||
uv tool install git+https://github.com/huggingface/leLab.git && lelab
|
||||
```
|
||||
|
||||
Once installed you will be able to run lelab anytime you want with `lelab` command from your terminal (above command has it included at the end so it will run it right after installation).
|
||||
|
||||
### Adding robots
|
||||
|
||||
##### Calibration
|
||||
You will need to select the proper arm type (leader or follower) and calibrate each arm as shown in the video available inside LeLab. Make sure that all joints are in the middle position when starting the calibration.
|
||||
|
||||
##### Adding cameras
|
||||
At the bottom of the add robot page you can also add the cameras and name them accordingly.
|
||||
|
||||
### Teleoperation
|
||||
Once the robots have been configured you can go back and click the teleoperation button. You will see the 3D visualization of the arm and will be able to control the follower with the leader. If something doesn't work there, remove and add your robot again following the steps described in LeLab.
|
||||
|
||||
### Recording a dataset
|
||||
Type a new name for your dataset and press on the plus button. You will need to provide:
|
||||
- Task description, for example "put the cube in a container"
|
||||
- Number of episodes that you want to record, at least 30 recommended
|
||||
- Episode and reset durations. These are max durations and can be shortened while recording with a spacebar press.
|
||||
- If you configured your cameras earlier you don't need to do that again.
|
||||
|
||||
Press start recording, wait for it to load, perform the task with confident movements but don't rush. Once the task is finished and you moved your robot to the initial position press the spacebar. You will have time to reset the environment for example grab the cube from the container and placing it on the desk again. Once ready press the spacebar and record the next episode. Repeat until all the episodes are recorded.
|
||||
|
||||
### Training a model
|
||||
This is the most powerful function with LeLab! You can easily train models locally on your own computer but also with [HF Jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) which gives you easy access to very powerful GPUs with clear pricing.
|
||||
|
||||
> [!TIP]
|
||||
> To use HF Jobs make sure that you are logged in to HF, you can do that by running `hf auth login` in the terminal.
|
||||
|
||||
In the training tab select if you want to train locally or specific HF hardware you want to use. You will also need to provide the dataset that will be used for training. Your own datasets will be listed in a dropdown list, you can also use other datasets by providing its id. Set the policy you want to train, batch size and number of steps. For guide on choosing hardware and batch size check out our [Compute HW Guide for LeRobot Training.](hardware_guide.mdx)
|
||||
|
||||
Once you start training the progress will be visualized inside LeLab. Checkpoints will be saved as well.
|
||||
### Running the model on a robot
|
||||
In the main view of the LeLab under jobs you will see all the models that you trained. To run the policy on the robot just click the green run button and press start inference. After loading the policy the robot should start solving the task that it learned during training.
|
||||
@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
|
||||
Once you are logged in, you can run inference in your setup by doing:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
||||
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
|
||||
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
|
||||
--dataset.episode_time_s=50 \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
# <- RTC optional, use when running on low power hardware \
|
||||
# --inference.type=rtc \
|
||||
# --inference.rtc.execution_horizon=10 \
|
||||
# --inference.rtc.max_guidance_weight=10.0 \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
# --teleop.id=my_red_leader_arm \
|
||||
# --display_data=true #optional use if you want to see the camera stream \
|
||||
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
|
||||
```
|
||||
|
||||
|
||||
@@ -15,10 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
|
||||
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
|
||||
|
||||
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
||||
of the source video, draws a progress line on each frame, and writes the result.
|
||||
The progress data is read from a parquet file that lives alongside the dataset
|
||||
(configurable via ``--progress-file``).
|
||||
|
||||
Usage:
|
||||
python examples/dataset/create_progress_videos.py \
|
||||
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
|
||||
TASK_FONT_SCALE = 0.55
|
||||
|
||||
|
||||
def download_episode_metadata(repo_id: str, episode: int) -> Path:
|
||||
"""Download only the metadata and sarm_progress files for a dataset.
|
||||
def download_episode_metadata(
|
||||
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> Path:
|
||||
"""Download only the metadata and per-frame progress file for a dataset.
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace dataset repository ID.
|
||||
episode: Episode index (used for logging only; all meta is fetched).
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Local cache path for the downloaded snapshot.
|
||||
"""
|
||||
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
|
||||
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
|
||||
local_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||
allow_patterns=["meta/**", progress_file],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
|
||||
return video_path
|
||||
|
||||
|
||||
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
|
||||
"""Load sarm_progress values for an episode.
|
||||
def load_progress_data(
|
||||
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> np.ndarray | None:
|
||||
"""Load per-frame progress values for an episode.
|
||||
|
||||
Args:
|
||||
local_path: Dataset cache root.
|
||||
episode: Episode index.
|
||||
progress_file: Filename of the per-frame progress parquet.
|
||||
|
||||
Returns:
|
||||
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
||||
"""
|
||||
parquet_path = local_path / "sarm_progress.parquet"
|
||||
parquet_path = local_path / progress_file
|
||||
if not parquet_path.exists():
|
||||
logging.warning("sarm_progress.parquet not found")
|
||||
logging.warning("%s not found", progress_file)
|
||||
return None
|
||||
df = pd.read_parquet(parquet_path)
|
||||
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
|
||||
logging.info(" %s columns: %s", progress_file, list(df.columns))
|
||||
episode_df = df[df["episode_index"] == episode].copy()
|
||||
if episode_df.empty:
|
||||
logging.warning("No sarm_progress rows for episode %d", episode)
|
||||
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
|
||||
return None
|
||||
episode_df = episode_df.sort_values("frame_index")
|
||||
|
||||
@@ -576,6 +585,7 @@ def process_dataset(
|
||||
camera_key: str | None,
|
||||
output_dir: Path,
|
||||
create_gif: bool = False,
|
||||
progress_file: str = "sarm_progress.parquet",
|
||||
) -> Path | None:
|
||||
"""Full pipeline: download, extract metadata, composite progress, write output.
|
||||
|
||||
@@ -585,6 +595,8 @@ def process_dataset(
|
||||
camera_key: Camera key to use, or None for auto-selection.
|
||||
output_dir: Directory to write output files.
|
||||
create_gif: If True, also generate a GIF from the MP4.
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Path to the final output file, or None on failure.
|
||||
@@ -592,7 +604,7 @@ def process_dataset(
|
||||
safe_name = repo_id.replace("/", "_")
|
||||
logging.info("Processing: %s | episode %d", repo_id, episode)
|
||||
|
||||
local_path = download_episode_metadata(repo_id, episode)
|
||||
local_path = download_episode_metadata(repo_id, episode, progress_file)
|
||||
logging.info(" Local cache: %s", local_path)
|
||||
|
||||
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
||||
@@ -600,9 +612,9 @@ def process_dataset(
|
||||
|
||||
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
||||
|
||||
progress_data = load_progress_data(local_path, episode)
|
||||
progress_data = load_progress_data(local_path, episode, progress_file)
|
||||
if progress_data is None:
|
||||
logging.error("Could not load sarm_progress data. Skipping overlay.")
|
||||
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
|
||||
return None
|
||||
|
||||
logging.info(" Progress frames: %d", len(progress_data))
|
||||
@@ -627,7 +639,7 @@ def process_dataset(
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
|
||||
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
@@ -658,6 +670,15 @@ def main() -> None:
|
||||
action="store_true",
|
||||
help="Also generate a GIF from the MP4 output.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-file",
|
||||
type=str,
|
||||
default="sarm_progress.parquet",
|
||||
help=(
|
||||
"Filename of the per-frame progress parquet inside the dataset repo "
|
||||
"(default: 'sarm_progress.parquet')."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
@@ -670,6 +691,7 @@ def main() -> None:
|
||||
camera_key=args.camera_key,
|
||||
output_dir=args.output_dir,
|
||||
create_gif=args.gif,
|
||||
progress_file=args.progress_file,
|
||||
)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -138,7 +138,9 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
||||
# Common
|
||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
|
||||
@@ -255,8 +255,7 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) ->
|
||||
remaining = config_data[field]
|
||||
if remaining:
|
||||
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
||||
else:
|
||||
del config_data[field]
|
||||
del config_data[field]
|
||||
modified = True
|
||||
|
||||
if not modified:
|
||||
@@ -311,7 +310,13 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
else:
|
||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
||||
if config_path_cli:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = draccus.parse(
|
||||
config_class=argtype,
|
||||
config_path=config_path_cli or config_path,
|
||||
args=cli_args,
|
||||
)
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
|
||||
@@ -18,12 +18,25 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.import_utils import _placo_available, require_package
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
if TYPE_CHECKING or _placo_available:
|
||||
_placo_runtime_error: ImportError | None = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import placo # type: ignore[import-not-found]
|
||||
else:
|
||||
placo = None
|
||||
try:
|
||||
import placo # type: ignore[import-not-found]
|
||||
except ImportError as _placo_import_err:
|
||||
placo = None
|
||||
_placo_runtime_error = _placo_import_err
|
||||
|
||||
|
||||
def _raise_if_placo_unusable() -> None:
|
||||
if placo is None and _placo_runtime_error is not None:
|
||||
raise ImportError(
|
||||
f"placo is installed but failed to import: {_placo_runtime_error!s}"
|
||||
) from _placo_runtime_error
|
||||
|
||||
|
||||
class RobotKinematics:
|
||||
@@ -44,6 +57,7 @@ class RobotKinematics:
|
||||
joint_names (list[str] | None): List of joint names to use for the kinematics solver
|
||||
"""
|
||||
require_package("placo", extra="placo-dep")
|
||||
_raise_if_placo_unusable()
|
||||
|
||||
self.robot = placo.RobotWrapper(urdf_path)
|
||||
self.solver = placo.KinematicsSolver(self.robot)
|
||||
|
||||
@@ -43,6 +43,7 @@ from .tables import (
|
||||
CAN_CMD_SET_ZERO,
|
||||
DEFAULT_BAUDRATE,
|
||||
DEFAULT_TIMEOUT_MS,
|
||||
HANDSHAKE_TIMEOUT_S,
|
||||
MODEL_RESOLUTION,
|
||||
MOTOR_LIMIT_PARAMS,
|
||||
NORMALIZED_DATA,
|
||||
@@ -215,14 +216,16 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
self._is_connected = False
|
||||
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
|
||||
|
||||
def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
|
||||
def _query_status_via_clear_fault(
|
||||
self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
|
||||
) -> tuple[bool, can.Message | None]:
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_id = self._get_motor_id(motor_name)
|
||||
recv_id = self._get_motor_recv_id(motor_name)
|
||||
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self._bus().send(msg)
|
||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
|
||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)
|
||||
|
||||
def _recv_status_via_clear_fault(
|
||||
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
|
||||
@@ -280,7 +283,7 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
faulted_motors = []
|
||||
|
||||
for motor_name in self.motors:
|
||||
has_fault, msg = self._query_status_via_clear_fault(motor_name)
|
||||
has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
|
||||
if msg is None:
|
||||
missing_motors.append(motor_name)
|
||||
elif has_fault:
|
||||
@@ -505,6 +508,87 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
|
||||
return responses
|
||||
|
||||
def _recv_all_messages_until_quiet(
|
||||
self,
|
||||
*,
|
||||
timeout: float = RUNNING_TIMEOUT,
|
||||
max_messages: int = 4096,
|
||||
) -> list[can.Message]:
|
||||
"""
|
||||
Receive frames until the bus goes quiet.
|
||||
|
||||
Args:
|
||||
timeout: Poll timeout used for each recv() call. Collection stops
|
||||
when one recv() times out (quiet gap).
|
||||
max_messages: Safety cap to prevent unbounded loops.
|
||||
"""
|
||||
out: list[can.Message] = []
|
||||
max_messages = max(1, max_messages)
|
||||
timeout = max(0.0, timeout)
|
||||
|
||||
try:
|
||||
while len(out) < max_messages:
|
||||
msg = self._bus().recv(timeout=timeout)
|
||||
if msg is None:
|
||||
break
|
||||
out.append(msg)
|
||||
except (can.CanError, OSError) as e:
|
||||
logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
|
||||
|
||||
return out
|
||||
|
||||
def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
|
||||
"""
|
||||
Decode all received feedback frames and update cached motor states.
|
||||
|
||||
Returns:
|
||||
Set of payload recv_ids that were successfully mapped to motors.
|
||||
"""
|
||||
processed_recv_ids: set[int] = set()
|
||||
for msg in messages:
|
||||
if len(msg.data) < 1:
|
||||
logger.debug(
|
||||
f"Dropping short CAN frame on {self.port} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
|
||||
)
|
||||
continue
|
||||
|
||||
recv_id = int(msg.data[0])
|
||||
motor_name = self._recv_id_to_motor.get(recv_id)
|
||||
if motor_name is None:
|
||||
logger.debug(
|
||||
f"Unmapped CAN frame on {self.port} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
|
||||
)
|
||||
continue
|
||||
|
||||
self._process_response(motor_name, msg)
|
||||
processed_recv_ids.add(recv_id)
|
||||
|
||||
return processed_recv_ids
|
||||
|
||||
def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
|
||||
"""
|
||||
Drain pending RX frames from the CAN interface.
|
||||
|
||||
This is used by higher-level controllers to drop stale feedback before issuing
|
||||
a fresh read cycle, so subsequent state reads are based on most recent replies.
|
||||
It should also be called once when a controller instance is created/connected,
|
||||
to clear residual frames left on the interface from previous sessions.
|
||||
"""
|
||||
drained = 0
|
||||
poll_timeout_s = max(0.0, poll_timeout_s)
|
||||
max_messages = max(1, max_messages)
|
||||
try:
|
||||
while drained < max_messages:
|
||||
msg = self._bus().recv(timeout=poll_timeout_s)
|
||||
if msg is None:
|
||||
break
|
||||
drained += 1
|
||||
except (can.CanError, OSError) as e:
|
||||
logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
|
||||
return drained
|
||||
|
||||
def _speed_control(
|
||||
self,
|
||||
motor: NameOrID,
|
||||
@@ -644,11 +728,14 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self._bus().send(msg)
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
# Read every feedback frame until RX goes quiet, then decode all of them.
|
||||
# This avoids dropping useful frames when responses from different motors interleave.
|
||||
messages = self._recv_all_messages_until_quiet()
|
||||
processed_recv_ids = self._process_feedback_messages(messages)
|
||||
|
||||
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
|
||||
for recv_id, motor_name in recv_id_to_motor.items():
|
||||
if msg := responses.get(recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
if recv_id not in processed_recv_ids:
|
||||
logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||
"""Convert float to unsigned integer for CAN transmission."""
|
||||
@@ -711,7 +798,10 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
try:
|
||||
self._decode_motor_state(msg.data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode response from {motor}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to decode response from {motor} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
|
||||
)
|
||||
|
||||
def _get_cached_value(self, motor: str, data_name: str) -> Value:
|
||||
"""Retrieve a specific value from the state cache."""
|
||||
@@ -848,20 +938,12 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
self._bus().send(msg)
|
||||
updated_motors.append(motor)
|
||||
|
||||
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
|
||||
responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
|
||||
|
||||
for response in responses.values():
|
||||
payload_motor_name = self._recv_id_to_motor.get(response.data[0])
|
||||
if payload_motor_name is not None:
|
||||
self._process_response(payload_motor_name, response)
|
||||
else:
|
||||
# Fallback: still attempt to decode based on payload byte0 mapping.
|
||||
self._decode_motor_state(response.data)
|
||||
messages = self._recv_all_messages_until_quiet()
|
||||
processed_recv_ids = self._process_feedback_messages(messages)
|
||||
|
||||
for motor in updated_motors:
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
if recv_id not in responses:
|
||||
if recv_id not in processed_recv_ids:
|
||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
|
||||
@@ -114,7 +114,8 @@ CAN_CMD_SAVE_PARAM = 0xAA
|
||||
CAN_PARAM_ID = 0x7FF
|
||||
|
||||
|
||||
RUNNING_TIMEOUT = 0.001
|
||||
RUNNING_TIMEOUT = 0.003
|
||||
HANDSHAKE_TIMEOUT_S = 0.05
|
||||
PARAM_TIMEOUT = 0.01
|
||||
|
||||
STATE_CACHE_TTL_S = 0.02
|
||||
|
||||
@@ -60,6 +60,7 @@ class Eagle25VLPreTrainedModel(PreTrainedModel):
|
||||
"SiglipEncoderLayer",
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
@@ -124,7 +124,6 @@ class Eagle25VLProcessor(ProcessorMixin):
|
||||
"videos_kwargs",
|
||||
"text_kwargs",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -26,9 +26,14 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from huggingface_hub.dataclasses import strict
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
|
||||
def strict(cls):
|
||||
return cls
|
||||
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
@@ -173,19 +178,20 @@ N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@strict
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
backbone_cfg: dict
|
||||
action_head_cfg: dict
|
||||
action_horizon: int
|
||||
action_dim: int
|
||||
backbone_cfg: dict[str, Any] | None = None
|
||||
action_head_cfg: dict[str, Any] | None = None
|
||||
action_horizon: int = 0
|
||||
action_dim: int = 0
|
||||
compute_dtype: str = "float32"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
def __post_init__(self, **kwargs):
|
||||
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
|
||||
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
|
||||
# real model
|
||||
|
||||
@@ -206,7 +206,11 @@ def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS
|
||||
"Vendor files are copied during model creation. Create the policy/model first, "
|
||||
"or call ensure_eagle_cache_ready() before building processors."
|
||||
)
|
||||
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
|
||||
proc = AutoProcessor.from_pretrained(
|
||||
str(cache_dir),
|
||||
trust_remote_code=True,
|
||||
fix_mistral_regex=False,
|
||||
)
|
||||
proc.tokenizer.padding_side = "left"
|
||||
return proc
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -141,6 +142,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -227,16 +237,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -258,15 +265,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -274,13 +282,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -488,8 +496,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -499,36 +508,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -907,7 +919,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -224,16 +234,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -255,15 +262,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -271,13 +279,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -485,8 +493,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -496,36 +505,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -880,7 +892,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -248,13 +248,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
def generate_model_card(
|
||||
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
||||
) -> ModelCard:
|
||||
base_model_mapping = {
|
||||
"smolvla": "lerobot/smolvla_base",
|
||||
"pi0": "lerobot/pi0_base",
|
||||
"pi05": "lerobot/pi05_base",
|
||||
"pi0_fast": "lerobot/pi0fast-base",
|
||||
"xvla": "lerobot/xvla-base",
|
||||
}
|
||||
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model
|
||||
|
||||
card_data = ModelCardData(
|
||||
license=license or "apache-2.0",
|
||||
@@ -263,7 +257,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
|
||||
model_name=model_type,
|
||||
datasets=dataset_repo_id,
|
||||
base_model=base_model_mapping(model_type, None),
|
||||
base_model=base_model,
|
||||
)
|
||||
|
||||
template_card = (
|
||||
|
||||
@@ -73,17 +73,14 @@ _Writes checkpoints to `outputs/train/<desired_policy_repo_id>/checkpoints/`._
|
||||
### Evaluate the policy/run inference
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video1, width: 640, height: 480, fps: 30}, side: {type: opencv, index_or_path: /dev/video5, width: 640, height: 480, fps: 30}}" \
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--dataset.repo_id=<hf_user>/eval_<dataset> \
|
||||
--policy.path=<hf_user>/<desired_policy_repo_id> \
|
||||
--task="Put lego brick into the transparent box" \
|
||||
--duration=60
|
||||
--episodes=10
|
||||
```
|
||||
|
||||
If you want to record a dataset while testing the policy use `--dataset.repo_id=<hf_user>/eval_dataset_name` it is important to use the prefix **eval\_**. For the policy path use the policy from the Hugging Face Hub or a local one. Skipping duration will make the policy run indefinitely.
|
||||
Prefix the dataset repo with **eval\_** and supply `--policy.path` pointing to a local or hub checkpoint.
|
||||
|
||||
---
|
||||
|
||||
|
||||
1
tests/policies/pi0_pi05/openpi_pytorch/__init__.py
Normal file
1
tests/policies/pi0_pi05/openpi_pytorch/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Lightweight vendored OpenPI PyTorch modules for PI0/PI05 parity tests."""
|
||||
22
tests/policies/pi0_pi05/openpi_pytorch/gemma.py
Normal file
22
tests/policies/pi0_pi05/openpi_pytorch/gemma.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
width: int
|
||||
depth: int
|
||||
mlp_dim: int
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
head_dim: int
|
||||
|
||||
|
||||
def get_config(variant: str) -> Config:
|
||||
"""Return the Gemma shape config needed by the OpenPI PyTorch model."""
|
||||
if variant == "dummy":
|
||||
return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16)
|
||||
if variant == "gemma_300m":
|
||||
return Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256)
|
||||
if variant == "gemma_2b":
|
||||
return Config(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256)
|
||||
raise ValueError(f"Unknown variant: {variant}")
|
||||
300
tests/policies/pi0_pi05/openpi_pytorch/gemma_pytorch.py
Normal file
300
tests/policies/pi0_pi05/openpi_pytorch/gemma_pytorch.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaForCausalLM,
|
||||
_gated_residual,
|
||||
layernorm_forward,
|
||||
)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vlm_config,
|
||||
action_expert_config,
|
||||
use_adarms=None,
|
||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||
):
|
||||
if use_adarms is None:
|
||||
use_adarms = [False, False]
|
||||
super().__init__()
|
||||
|
||||
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
||||
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
||||
vlm_config_hf.image_token_index = 257152
|
||||
vlm_config_hf.text_config.hidden_size = vlm_config.width
|
||||
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
|
||||
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
|
||||
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
|
||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||
head_dim=action_expert_config.head_dim,
|
||||
hidden_size=action_expert_config.width,
|
||||
intermediate_size=action_expert_config.mlp_dim,
|
||||
num_attention_heads=action_expert_config.num_heads,
|
||||
num_hidden_layers=action_expert_config.depth,
|
||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||
vocab_size=257152,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
dtype="float32",
|
||||
use_adarms=use_adarms[1],
|
||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||
)
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
|
||||
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
||||
if precision == "bfloat16":
|
||||
self.to(dtype=torch.bfloat16)
|
||||
elif precision == "float32":
|
||||
self.to(dtype=torch.float32)
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
]
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if any(selector in name for selector in params_to_keep_float32):
|
||||
param.data = param.data.to(dtype=torch.float32)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
# Transformers 5.4 no longer divides PaliGemma image features by sqrt(hidden_size),
|
||||
# so the upstream helper now matches OpenPI's patched PaliGemma image-scale semantics.
|
||||
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-c916907e7e52ac85ee1a1527560eae4656cd6c76141ceb1fe3da61bd5f697d2a
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
inputs_embeds: list[torch.FloatTensor] | None = None,
|
||||
use_cache: bool | None = None,
|
||||
adarms_cond: list[torch.Tensor] | None = None,
|
||||
):
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
|
||||
)
|
||||
prefix_past_key_values = prefix_output.past_key_values
|
||||
prefix_output = prefix_output.last_hidden_state
|
||||
suffix_output = None
|
||||
elif inputs_embeds[0] is None:
|
||||
suffix_output = self.gemma_expert.model.forward(
|
||||
inputs_embeds=inputs_embeds[1],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
|
||||
)
|
||||
suffix_output = suffix_output.last_hidden_state
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
hasattr(self.gemma_expert.model, "gradient_checkpointing")
|
||||
and self.gemma_expert.model.gradient_checkpointing
|
||||
and self.training
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Force enable gradient checkpointing if we're in training mode and the model supports it
|
||||
if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
|
||||
if not self.gemma_expert.model.gradient_checkpointing:
|
||||
print("Forcing gradient checkpointing to be enabled for Gemma expert model")
|
||||
self.gemma_expert.model.gradient_checkpointing = True
|
||||
use_gradient_checkpointing = True
|
||||
|
||||
# Debug gradient checkpointing status
|
||||
if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
|
||||
print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
|
||||
print(f"Model training mode: {self.training}")
|
||||
print(
|
||||
f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
|
||||
)
|
||||
if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
|
||||
print(
|
||||
f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
|
||||
)
|
||||
self._debug_gc_printed = True
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layernorm_forward(
|
||||
layer.input_layernorm, hidden_states, adarms_cond[i]
|
||||
)
|
||||
gates.append(gate)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# Concatenate and process attention
|
||||
query_states = torch.cat(query_states, dim=2)
|
||||
key_states = torch.cat(key_states, dim=2)
|
||||
value_states = torch.cat(value_states, dim=2)
|
||||
|
||||
dummy_tensor = torch.zeros(
|
||||
query_states.shape[0],
|
||||
query_states.shape[2],
|
||||
query_states.shape[-1],
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = self.paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
self.paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = self.paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||
|
||||
# first residual
|
||||
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||
after_first_residual = out_emb.clone()
|
||||
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
out_emb = layer.mlp(out_emb)
|
||||
# second residual
|
||||
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||
outputs_embeds.append(out_emb)
|
||||
start_pos = end_pos
|
||||
|
||||
return outputs_embeds
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
# Old code removed - now using compute_layer_complete function above
|
||||
|
||||
# final norm
|
||||
# Define final norm computation function for gradient checkpointing
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
# Apply gradient checkpointing to final norm if enabled
|
||||
if use_gradient_checkpointing:
|
||||
outputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_final_norms,
|
||||
inputs_embeds,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
)
|
||||
else:
|
||||
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
|
||||
|
||||
prefix_output = outputs_embeds[0]
|
||||
suffix_output = outputs_embeds[1]
|
||||
prefix_past_key_values = None
|
||||
|
||||
return [prefix_output, suffix_output], prefix_past_key_values
|
||||
79
tests/policies/pi0_pi05/openpi_pytorch/image_tools.py
Normal file
79
tests/policies/pi0_pi05/openpi_pytorch/image_tools.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
|
||||
|
||||
def resize_with_pad_torch(
|
||||
images: torch.Tensor,
|
||||
height: int,
|
||||
width: int,
|
||||
mode: str = "bilinear",
|
||||
) -> torch.Tensor:
|
||||
"""PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
|
||||
by padding with black. If the image is float32, it must be in the range [-1, 1].
|
||||
|
||||
Args:
|
||||
images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
|
||||
height: Target height
|
||||
width: Target width
|
||||
mode: Interpolation mode ('bilinear', 'nearest', etc.)
|
||||
|
||||
Returns:
|
||||
Resized and padded tensor with same shape format as input
|
||||
"""
|
||||
# Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
|
||||
if images.shape[-1] <= 4: # Assume channels-last format
|
||||
channels_last = True
|
||||
# Convert to channels-first for torch operations
|
||||
if images.dim() == 3:
|
||||
images = images.unsqueeze(0) # Add batch dimension
|
||||
images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w]
|
||||
else:
|
||||
channels_last = False
|
||||
if images.dim() == 3:
|
||||
images = images.unsqueeze(0) # Add batch dimension
|
||||
|
||||
batch_size, channels, cur_height, cur_width = images.shape
|
||||
|
||||
# Calculate resize ratio
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
|
||||
# Resize
|
||||
resized_images = F.interpolate(
|
||||
images,
|
||||
size=(resized_height, resized_width),
|
||||
mode=mode,
|
||||
align_corners=False if mode == "bilinear" else None,
|
||||
)
|
||||
|
||||
# Handle dtype-specific clipping
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
|
||||
# Calculate padding
|
||||
pad_h0, remainder_h = divmod(height - resized_height, 2)
|
||||
pad_h1 = pad_h0 + remainder_h
|
||||
pad_w0, remainder_w = divmod(width - resized_width, 2)
|
||||
pad_w1 = pad_w0 + remainder_w
|
||||
|
||||
# Pad
|
||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
||||
padded_images = F.pad(
|
||||
resized_images,
|
||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||
mode="constant",
|
||||
value=constant_value,
|
||||
)
|
||||
|
||||
# Convert back to original format if needed
|
||||
if channels_last:
|
||||
padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
if batch_size == 1 and images.shape[0] == 1:
|
||||
padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
|
||||
|
||||
return padded_images
|
||||
471
tests/policies/pi0_pi05/openpi_pytorch/pi0_pytorch.py
Normal file
471
tests/policies/pi0_pi05/openpi_pytorch/pi0_pytorch.py
Normal file
@@ -0,0 +1,471 @@
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
import tests.policies.pi0_pi05.openpi_pytorch.gemma as _gemma
|
||||
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as _preprocessing
|
||||
from tests.policies.pi0_pi05.openpi_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "cpu":
|
||||
# CPU doesn't support bfloat16, use float32 instead
|
||||
if target_dtype == torch.bfloat16:
|
||||
return torch.float32
|
||||
if target_dtype == torch.float64:
|
||||
return torch.float64
|
||||
return target_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
|
||||
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
||||
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,))
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
"""Copied from big_vision.
|
||||
|
||||
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
||||
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
||||
setup several types of attention, for example:
|
||||
|
||||
[[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
themselves and the last 3 tokens have a causal attention. The first
|
||||
entry could also be a 1 without changing behaviour.
|
||||
|
||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
Args:
|
||||
input_mask: bool[B, N] true if its part of the input, false if padding.
|
||||
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
||||
it and 0 where it shares the same attention mask as the previous token.
|
||||
"""
|
||||
if att_masks.ndim != 2:
|
||||
raise ValueError(att_masks.ndim)
|
||||
if pad_masks.ndim != 2:
|
||||
raise ValueError(pad_masks.ndim)
|
||||
|
||||
cumsum = torch.cumsum(att_masks, dim=1)
|
||||
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
class PI0Pytorch(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.pi05 = config.pi05
|
||||
|
||||
paligemma_config = _gemma.get_config(config.paligemma_variant)
|
||||
action_expert_config = _gemma.get_config(config.action_expert_variant)
|
||||
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||
paligemma_config,
|
||||
action_expert_config,
|
||||
use_adarms=[False, True] if self.pi05 else [False, False],
|
||||
precision=config.dtype,
|
||||
)
|
||||
|
||||
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)
|
||||
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)
|
||||
|
||||
if self.pi05:
|
||||
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
else:
|
||||
self.state_proj = nn.Linear(config.action_dim, action_expert_config.width)
|
||||
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
|
||||
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
|
||||
torch.set_float32_matmul_precision("high")
|
||||
if config.pytorch_compile_mode is not None:
|
||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode)
|
||||
|
||||
# Initialize gradient checkpointing flag
|
||||
self.gradient_checkpointing_enabled = False
|
||||
|
||||
# The upstream OpenPI module verifies a site-package Transformers patch here.
|
||||
# This vendored test copy instead routes through LeRobot's local PiGemma compatibility layer.
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||
|
||||
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
|
||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def is_gradient_checkpointing_enabled(self):
|
||||
"""Check if gradient checkpointing is enabled."""
|
||||
return self.gradient_checkpointing_enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def _prepare_attention_masks_4d(self, att_2d_masks):
|
||||
"""Helper method to prepare 4D attention masks for transformer."""
|
||||
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
||||
return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
|
||||
|
||||
def _preprocess_observation(self, observation, *, train=True):
|
||||
"""Helper method to preprocess observation."""
|
||||
observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
|
||||
return (
|
||||
list(observation.images.values()),
|
||||
list(observation.image_masks.values()),
|
||||
observation.tokenized_prompt,
|
||||
observation.tokenized_prompt_mask,
|
||||
observation.state,
|
||||
)
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
return torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
||||
for PaliGemma transformer processing.
|
||||
"""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# Process images
|
||||
for img, img_mask in zip(images, img_masks, strict=True):
|
||||
|
||||
def image_embed_func(img):
|
||||
return self.paligemma_with_expert.embed_image(img)
|
||||
|
||||
img_emb = self._apply_checkpoint(image_embed_func, img)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
||||
|
||||
# Create attention masks so that image tokens attend to each other
|
||||
att_masks += [0] * num_img_embs
|
||||
|
||||
# Process language tokens
|
||||
def lang_embed_func(lang_tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||
# Transformers > 5.4 scales Gemma token embeddings inside embed_tokens, matching
|
||||
# OpenPI's former explicit sqrt(hidden_size) multiply without applying it twice.
|
||||
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834
|
||||
return lang_emb
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
||||
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(lang_masks)
|
||||
|
||||
# full attention between image and language inputs
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
|
||||
# Get batch size from the first dimension of the concatenated tensors
|
||||
bsize = pad_masks.shape[0]
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def embed_suffix(self, state, noisy_actions, timestep):
|
||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
if not self.pi05:
|
||||
if self.state_proj.weight.dtype == torch.float32:
|
||||
state = state.to(torch.float32)
|
||||
|
||||
# Embed state
|
||||
def state_proj_func(state):
|
||||
return self.state_proj(state)
|
||||
|
||||
state_emb = self._apply_checkpoint(state_proj_func, state)
|
||||
|
||||
embs.append(state_emb[:, None, :])
|
||||
bsize = state_emb.shape[0]
|
||||
device = state_emb.device
|
||||
|
||||
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
||||
pad_masks.append(state_mask)
|
||||
|
||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||
att_masks += [1]
|
||||
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.action_in_proj.out_features,
|
||||
min_period=4e-3,
|
||||
max_period=4.0,
|
||||
device=timestep.device,
|
||||
)
|
||||
time_emb = time_emb.type(dtype=timestep.dtype)
|
||||
|
||||
# Fuse timestep + action information using an MLP
|
||||
def action_proj_func(noisy_actions):
|
||||
return self.action_in_proj(noisy_actions)
|
||||
|
||||
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
|
||||
if not self.pi05:
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
# Apply MLP layers
|
||||
def mlp_func(action_time_emb):
|
||||
x = self.action_time_mlp_in(action_time_emb)
|
||||
x = F.silu(x) # swish == silu
|
||||
return self.action_time_mlp_out(x)
|
||||
|
||||
action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
|
||||
adarms_cond = None
|
||||
else:
|
||||
# time MLP (for adaRMS)
|
||||
def time_mlp_func(time_emb):
|
||||
x = self.time_mlp_in(time_emb)
|
||||
x = F.silu(x) # swish == silu
|
||||
x = self.time_mlp_out(x)
|
||||
return F.silu(x)
|
||||
|
||||
time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
|
||||
action_time_emb = action_emb
|
||||
adarms_cond = time_emb
|
||||
|
||||
# Add to input tokens
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
att_masks += [1] + ([0] * (self.config.action_horizon - 1))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(self, observation, actions, noise=None, time=None) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
|
||||
observation, train=True
|
||||
)
|
||||
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
|
||||
# Prepare attention masks
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
|
||||
|
||||
# Apply gradient checkpointing if enabled
|
||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
return suffix_out
|
||||
|
||||
suffix_out = self._apply_checkpoint(
|
||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
|
||||
# Apply gradient checkpointing to final action projection if enabled
|
||||
def action_out_proj_func(suffix_out):
|
||||
return self.action_out_proj(suffix_out)
|
||||
|
||||
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
|
||||
|
||||
return F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = observation.state.shape[0]
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
|
||||
observation, train=False
|
||||
)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
# Compute image and language key value cache
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks_4d,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
|
||||
# Euler step - use new tensor assignment instead of in-place operation
|
||||
x_t = x_t + dt * v_t
|
||||
time += dt
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
||||
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
# Prepare attention masks
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[None, suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
return self.action_out_proj(suffix_out)
|
||||
179
tests/policies/pi0_pi05/openpi_pytorch/preprocessing_pytorch.py
Normal file
179
tests/policies/pi0_pi05/openpi_pytorch/preprocessing_pytorch.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from tests.policies.pi0_pi05.openpi_pytorch import image_tools
|
||||
|
||||
logger = logging.getLogger("openpi")
|
||||
|
||||
# Constants moved from model.py
|
||||
IMAGE_KEYS = (
|
||||
"base_0_rgb",
|
||||
"left_wrist_0_rgb",
|
||||
"right_wrist_0_rgb",
|
||||
)
|
||||
|
||||
IMAGE_RESOLUTION = (224, 224)
|
||||
|
||||
|
||||
def preprocess_observation_pytorch(
|
||||
observation,
|
||||
*,
|
||||
train: bool = False,
|
||||
image_keys: Sequence[str] = IMAGE_KEYS,
|
||||
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
|
||||
):
|
||||
"""Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
|
||||
|
||||
This function avoids complex type annotations that can cause torch.compile issues.
|
||||
"""
|
||||
if not set(image_keys).issubset(observation.images):
|
||||
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
|
||||
|
||||
batch_shape = observation.state.shape[:-1]
|
||||
|
||||
out_images = {}
|
||||
for key in image_keys:
|
||||
image = observation.images[key]
|
||||
|
||||
# TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
|
||||
# Handle both [B, C, H, W] and [B, H, W, C] formats
|
||||
is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1
|
||||
|
||||
if is_channels_first:
|
||||
# Convert [B, C, H, W] to [B, H, W, C] for processing
|
||||
image = image.permute(0, 2, 3, 1)
|
||||
|
||||
if image.shape[1:3] != image_resolution:
|
||||
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
|
||||
image = image_tools.resize_with_pad_torch(image, *image_resolution)
|
||||
|
||||
if train:
|
||||
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
|
||||
image = image / 2.0 + 0.5
|
||||
|
||||
# Apply PyTorch-based augmentations
|
||||
if "wrist" not in key:
|
||||
# Geometric augmentations for non-wrist cameras
|
||||
height, width = image.shape[1:3]
|
||||
|
||||
# Random crop and resize
|
||||
crop_height = int(height * 0.95)
|
||||
crop_width = int(width * 0.95)
|
||||
|
||||
# Random crop
|
||||
max_h = height - crop_height
|
||||
max_w = width - crop_width
|
||||
if max_h > 0 and max_w > 0:
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
|
||||
start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
|
||||
image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
|
||||
|
||||
# Resize back to original size
|
||||
image = torch.nn.functional.interpolate(
|
||||
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
size=(height, width),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Random rotation (small angles)
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees
|
||||
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
|
||||
# Convert to radians
|
||||
angle_rad = angle * torch.pi / 180.0
|
||||
|
||||
# Create rotation matrix
|
||||
cos_a = torch.cos(angle_rad)
|
||||
sin_a = torch.sin(angle_rad)
|
||||
|
||||
# Apply rotation using grid_sample
|
||||
grid_x = torch.linspace(-1, 1, width, device=image.device)
|
||||
grid_y = torch.linspace(-1, 1, height, device=image.device)
|
||||
|
||||
# Create meshgrid
|
||||
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
|
||||
|
||||
# Expand to batch dimension
|
||||
grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
|
||||
grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)
|
||||
|
||||
# Apply rotation transformation
|
||||
grid_x_rot = grid_x * cos_a - grid_y * sin_a
|
||||
grid_y_rot = grid_x * sin_a + grid_y * cos_a
|
||||
|
||||
# Stack and reshape for grid_sample
|
||||
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
|
||||
|
||||
image = torch.nn.functional.grid_sample(
|
||||
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
grid,
|
||||
mode="bilinear",
|
||||
padding_mode="zeros",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Color augmentations for all cameras
|
||||
# Random brightness
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
brightness_factor = (
|
||||
0.7 + torch.rand(1, device=image.device) * 0.6
|
||||
) # Random factor between 0.7 and 1.3
|
||||
image = image * brightness_factor
|
||||
|
||||
# Random contrast
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
contrast_factor = (
|
||||
0.6 + torch.rand(1, device=image.device) * 0.8
|
||||
) # Random factor between 0.6 and 1.4
|
||||
mean = image.mean(dim=[1, 2, 3], keepdim=True)
|
||||
image = (image - mean) * contrast_factor + mean
|
||||
|
||||
# Random saturation (convert to HSV, modify S, convert back)
|
||||
# For simplicity, we'll just apply a random scaling to the color channels
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
saturation_factor = (
|
||||
0.5 + torch.rand(1, device=image.device) * 1.0
|
||||
) # Random factor between 0.5 and 1.5
|
||||
gray = image.mean(dim=-1, keepdim=True)
|
||||
image = gray + (image - gray) * saturation_factor
|
||||
|
||||
# Clamp values to [0, 1]
|
||||
image = torch.clamp(image, 0, 1)
|
||||
|
||||
# Back to [-1, 1]
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
# Convert back to [B, C, H, W] format if it was originally channels-first
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
|
||||
|
||||
out_images[key] = image
|
||||
|
||||
# obtain mask
|
||||
out_masks = {}
|
||||
for key in out_images:
|
||||
if key not in observation.image_masks:
|
||||
# do not mask by default
|
||||
out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)
|
||||
else:
|
||||
out_masks[key] = observation.image_masks[key]
|
||||
|
||||
# Create a simple object with the required attributes instead of using the complex Observation class
|
||||
class SimpleProcessedObservation:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
return SimpleProcessedObservation(
|
||||
images=out_images,
|
||||
image_masks=out_masks,
|
||||
state=observation.state,
|
||||
tokenized_prompt=observation.tokenized_prompt,
|
||||
tokenized_prompt_mask=observation.tokenized_prompt_mask,
|
||||
token_ar_mask=observation.token_ar_mask,
|
||||
token_loss_mask=observation.token_loss_mask,
|
||||
)
|
||||
101
tests/policies/pi0_pi05/test_pi05_compile.py
Normal file
101
tests/policies/pi0_pi05/test_pi05_compile.py
Normal file
@@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi05 import PI05Config # noqa: E402
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
|
||||
assert_cache_stability,
|
||||
assert_compiled_output_matches_eager,
|
||||
assert_explain_has_no_graph_breaks,
|
||||
benchmark_runtime,
|
||||
make_compile_config,
|
||||
reset_compile_state,
|
||||
)
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
|
||||
def _make_model(*, compile_model):
|
||||
return PI05Pytorch(make_compile_config(PI05Config, compile_model=compile_model)).cuda().eval()
|
||||
|
||||
|
||||
def _make_dummy_inputs(config):
|
||||
device = torch.device("cuda")
|
||||
common = {
|
||||
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
|
||||
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
|
||||
"tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
|
||||
"masks": torch.ones(1, 5, dtype=torch.bool, device=device),
|
||||
}
|
||||
forward_kwargs = {
|
||||
**common,
|
||||
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"time": torch.rand(1, device=device),
|
||||
}
|
||||
sample_kwargs = {
|
||||
**common,
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"num_steps": config.num_inference_steps,
|
||||
}
|
||||
return forward_kwargs, sample_kwargs
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_torch_compile_forward_and_sample_actions():
|
||||
if not hasattr(torch, "compile"):
|
||||
pytest.skip("torch.compile is not available")
|
||||
if not torch._dynamo.is_dynamo_supported():
|
||||
pytest.skip("torch._dynamo is not supported on this platform")
|
||||
|
||||
torch.manual_seed(0)
|
||||
eager_model = _make_model(compile_model=False)
|
||||
torch.manual_seed(0)
|
||||
compiled_model = _make_model(compile_model=True)
|
||||
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
|
||||
|
||||
try:
|
||||
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
|
||||
|
||||
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi05.forward")
|
||||
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi05.sample_actions")
|
||||
|
||||
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi05.forward")
|
||||
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi05.sample_actions")
|
||||
|
||||
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi05.forward")
|
||||
benchmark_runtime(
|
||||
eager_model.sample_actions,
|
||||
compiled_model.sample_actions,
|
||||
sample_kwargs,
|
||||
"pi05.sample_actions",
|
||||
)
|
||||
finally:
|
||||
reset_compile_state()
|
||||
del eager_model
|
||||
del compiled_model
|
||||
torch.cuda.empty_cache()
|
||||
@@ -14,52 +14,56 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation"""
|
||||
"""Compare LeRobot PI0.5 against the vendored OpenPI PyTorch reference."""
|
||||
|
||||
import gc
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip if openpi or transformers is not available
|
||||
pytest.importorskip("openpi")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
from lerobot.configs import PreTrainedConfig # noqa: E402
|
||||
from lerobot.policies.pi05 import PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
deterministic_openpi_forward_preprocess,
|
||||
fix_reference_state_dict,
|
||||
fixed_flow_sampling,
|
||||
load_openpi_reference_state_dict,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
|
||||
from lerobot.types import PolicyAction # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 200
|
||||
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
COMPILE_MODE = "default"
|
||||
FORWARD_RTOL = 1e-4
|
||||
FORWARD_ATOL = 1e-4
|
||||
SAMPLE_RTOL = 1e-2
|
||||
SAMPLE_ATOL = 5e-3
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
"action": {
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
@@ -88,6 +92,15 @@ DUMMY_DATASET_STATS = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_cuda_after_test():
|
||||
yield
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
class PI05BaseOriginalConfig:
|
||||
action_dim: int = DUMMY_ACTION_DIM
|
||||
action_horizon: int = DUMMY_ACTION_HORIZON
|
||||
@@ -96,341 +109,163 @@ class PI05BaseOriginalConfig:
|
||||
precision: str = "float32"
|
||||
pi05: bool = True
|
||||
dtype: str = "float32"
|
||||
pytorch_compile_mode: str | None = None
|
||||
|
||||
|
||||
def instantiate_lerobot_pi05(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI05Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI05Policy.from_pretrained(pretrained_name_or_path="lerobot/pi05_base", strict=True)
|
||||
else:
|
||||
config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI05Policy(config)
|
||||
def instantiate_lerobot_pi05(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
config = PreTrainedConfig.from_pretrained("lerobot/pi05_base")
|
||||
config.device = str(DEVICE)
|
||||
config.dtype = "float32"
|
||||
config.compile_model = compile_model
|
||||
config.compile_mode = COMPILE_MODE
|
||||
config.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
policy = PI05Policy.from_pretrained("lerobot/pi05_base", config=config, strict=True)
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi05_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
policy.config.device = str(DEVICE)
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
return policy, preprocessor
|
||||
|
||||
|
||||
def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None):
|
||||
config = PI05BaseOriginalConfig()
|
||||
policy = PI0Pytorch(config)
|
||||
def instantiate_original_pi05():
|
||||
policy = PI0Pytorch(PI05BaseOriginalConfig()).to(DEVICE)
|
||||
|
||||
if from_pretrained:
|
||||
try:
|
||||
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi05_base)...")
|
||||
|
||||
# Download the model from HuggingFace Hub
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the entire repository
|
||||
if model_path and os.path.exists(model_path):
|
||||
cache_dir = model_path
|
||||
print(f"Using cached model from: {cache_dir}")
|
||||
else:
|
||||
cache_dir = snapshot_download(repo_id="lerobot/pi05_base", repo_type="model")
|
||||
print(f"Downloaded model to: {cache_dir}")
|
||||
|
||||
# Try to load safetensors format first
|
||||
model_file = os.path.join(cache_dir, "model.safetensors")
|
||||
if os.path.exists(model_file):
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
print(f"Loaded {len(state_dict)} parameters from safetensors")
|
||||
else:
|
||||
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
|
||||
|
||||
# Load the state dict into the model
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys: {len(missing_keys)}")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in missing_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"Unexpected keys: {len(unexpected_keys)}")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in unexpected_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("All pretrained weights loaded successfully!")
|
||||
else:
|
||||
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load pretrained weights: {e}")
|
||||
print(" Using randomly initialized weights...")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
policy.to(DEVICE)
|
||||
# NOTE: `lerobot/pi05_base` 的 LeRobot loader 和 PI0 一样会在 strict load 前做 key
|
||||
# 兼容转换,因此预期没有 missing_keys 或 unexpected_keys。vendored reference 则是裸
|
||||
# `nn.Module`,需要在测试侧补齐 checkpoint 与模块命名之间的最小差异。
|
||||
# NOTE: `lm_head.weight` 是 PaliGemma tied embedding 的保存名;LeRobot 的
|
||||
# from_pretrained 会把它映射到内部 `embed_tokens.weight`,而 reference 模型没有这层
|
||||
# loader,所以这里手动复用同一份 tensor,避免把权重别名差异误判成模型差异。
|
||||
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi05_base"))
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
assert missing_keys == []
|
||||
assert unexpected_keys == []
|
||||
return policy
|
||||
|
||||
|
||||
def create_dummy_data():
|
||||
batch_size = 2 # Reduce batch size for testing
|
||||
device = DEVICE
|
||||
|
||||
# Use the exact same prompt for both implementations
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
"observation.images.left_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
"observation.images.right_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
|
||||
"""Extract the exact same processed inputs that LeRobot uses internally."""
|
||||
# Get the tokenized language from LeRobot's internal method
|
||||
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
|
||||
|
||||
# Get the preprocessed images from LeRobot's internal method
|
||||
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for original implementation
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
|
||||
def prepare_parity_inputs(lerobot_pi05, lerobot_preprocessor):
|
||||
torch.manual_seed(0)
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
)
|
||||
openpi_actions = openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
)
|
||||
assert_processor_inputs_match_lerobot(
|
||||
lerobot_pi05,
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=False,
|
||||
)
|
||||
batch_size = raw_batch[OBS_STATE].shape[0]
|
||||
noise = torch.randn(
|
||||
batch_size,
|
||||
DUMMY_ACTION_HORIZON,
|
||||
DUMMY_ACTION_DIM,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE,
|
||||
)
|
||||
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
|
||||
return lerobot_batch, openpi_observation, openpi_actions, noise, time
|
||||
|
||||
|
||||
class PI05Observation:
|
||||
"""Observation class that matches the original OpenPI format."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state,
|
||||
images,
|
||||
image_masks,
|
||||
tokenized_prompt,
|
||||
tokenized_prompt_mask,
|
||||
token_ar_mask,
|
||||
token_loss_mask,
|
||||
):
|
||||
self.state = state
|
||||
self.images = images
|
||||
self.image_masks = image_masks
|
||||
self.tokenized_prompt = tokenized_prompt
|
||||
self.tokenized_prompt_mask = tokenized_prompt_mask
|
||||
self.token_ar_mask = token_ar_mask
|
||||
self.token_loss_mask = token_loss_mask
|
||||
|
||||
|
||||
def create_original_observation_with_openpi_preprocessing(batch):
|
||||
"""Create observation object for OpenPI using OpenPI's own preprocessing with pi05 state tokenizer."""
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
device = batch["observation.state"].device
|
||||
|
||||
# Create tokenizer for OpenPI (same as LeRobot uses)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
# Get task description (pi05 processor handles all text formatting)
|
||||
tasks = batch.get("task", ["Pick up the object"] * batch_size)
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks] * batch_size
|
||||
elif len(tasks) == 1:
|
||||
tasks = tasks * batch_size
|
||||
|
||||
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
|
||||
state = batch["observation.state"]
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
|
||||
state = pad_vector(state, DUMMY_STATE_DIM)
|
||||
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
# Create pi05-formatted prompts that include state information
|
||||
full_prompts = []
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
tokenized = tokenizer(
|
||||
full_prompts,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=DUMMY_MAX_TOKEN_LEN,
|
||||
return_tensors="pt",
|
||||
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(
|
||||
compile_model=compile_model,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
original_pi05 = instantiate_original_pi05()
|
||||
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
|
||||
lerobot_pi05,
|
||||
lerobot_preprocessor,
|
||||
)
|
||||
|
||||
lang_tokens = tokenized["input_ids"].to(device)
|
||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
if gradient_checkpointing:
|
||||
lerobot_pi05.train()
|
||||
else:
|
||||
lerobot_pi05.eval()
|
||||
original_pi05.eval()
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for OpenPI
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
with fixed_flow_sampling(lerobot_pi05.model, noise=noise, time=time):
|
||||
lerobot_loss, _ = lerobot_pi05(lerobot_batch, reduction="none")
|
||||
with deterministic_openpi_forward_preprocess(original_pi05):
|
||||
openpi_losses = original_pi05(openpi_observation, openpi_actions, noise=noise, time=time)
|
||||
openpi_loss = openpi_losses.mean(dim=(1, 2))
|
||||
|
||||
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
|
||||
image_dict = {
|
||||
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
|
||||
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
}
|
||||
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
|
||||
# Create image masks (all ones for real images)
|
||||
image_masks_dict = {}
|
||||
for key in image_dict:
|
||||
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||
|
||||
# Create raw observation object (before preprocessing)
|
||||
raw_observation = PI05Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
|
||||
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(compile_model=compile_model)
|
||||
original_pi05 = instantiate_original_pi05()
|
||||
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
|
||||
lerobot_pi05,
|
||||
lerobot_preprocessor,
|
||||
)
|
||||
|
||||
# Now use OpenPI's preprocessing
|
||||
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
|
||||
|
||||
return processed_obs
|
||||
|
||||
|
||||
def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
|
||||
_batch_size = batch["observation.state"].shape[0]
|
||||
_device = batch["observation.state"].device
|
||||
|
||||
# Extract the exact same processed inputs that LeRobot uses
|
||||
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
|
||||
extract_lerobot_processed_inputs(lerobot_pi0, batch)
|
||||
)
|
||||
|
||||
# Convert images list to dict with original OpenPI keys
|
||||
image_dict = {
|
||||
"base_0_rgb": images[0],
|
||||
"left_wrist_0_rgb": images[1],
|
||||
"right_wrist_0_rgb": images[2],
|
||||
}
|
||||
|
||||
# Convert image masks list to dict with original OpenPI keys
|
||||
image_masks_dict = {
|
||||
"base_0_rgb": img_masks[0],
|
||||
"left_wrist_0_rgb": img_masks[1],
|
||||
"right_wrist_0_rgb": img_masks[2],
|
||||
}
|
||||
|
||||
return PI05Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
|
||||
def test_pi05_original_vs_lerobot():
|
||||
"""Test PI05 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05(
|
||||
from_pretrained=True
|
||||
) # Load pretrained LeRobot model
|
||||
original_pi0 = instantiate_original_pi05(
|
||||
from_pretrained=True
|
||||
) # Load pretrained OpenPI model from HuggingFace Hub
|
||||
|
||||
print("Creating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
batch_lerobot = deepcopy(batch)
|
||||
|
||||
# Test each model with its own preprocessing (more realistic end-to-end test)
|
||||
print("\nTest each model with its own preprocessing")
|
||||
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||
|
||||
print(f"Task prompt: '{batch['task'][0]}'")
|
||||
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
|
||||
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
|
||||
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
|
||||
|
||||
print("Testing OpenPI with own preprocessing...")
|
||||
original_pi0.eval()
|
||||
torch.manual_seed(42) # Set seed for reproducibility
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
|
||||
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
|
||||
)
|
||||
openpi_actions_unit = openpi_actions[:, 0, :]
|
||||
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
|
||||
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
|
||||
|
||||
print("Testing LeRobot with own preprocessing...")
|
||||
lerobot_pi05.eval()
|
||||
torch.manual_seed(42) # Set the same seed
|
||||
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
original_pi05.eval()
|
||||
with torch.no_grad():
|
||||
lerobot_actions_own = lerobot_pi05.predict_action_chunk(
|
||||
batch_lerobot_processed
|
||||
) # batch_size, n_action_steps, action_dim
|
||||
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
|
||||
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
|
||||
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
|
||||
lerobot_actions = lerobot_pi05.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
|
||||
openpi_actions = original_pi05.sample_actions(
|
||||
device=DEVICE,
|
||||
observation=openpi_observation,
|
||||
noise=noise,
|
||||
num_steps=10,
|
||||
)
|
||||
|
||||
print("\nComparing end-to-end implementations:")
|
||||
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
|
||||
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
||||
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
|
||||
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
|
||||
|
||||
def test_pi05_forward_matches_openpi():
|
||||
assert_forward_matches()
|
||||
|
||||
|
||||
def test_pi05_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi()
|
||||
|
||||
|
||||
def test_pi05_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(gradient_checkpointing=True)
|
||||
|
||||
|
||||
def test_pi05_compile_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True)
|
||||
|
||||
|
||||
def test_pi05_compile_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi(compile_model=True)
|
||||
|
||||
|
||||
def test_pi05_compile_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True, gradient_checkpointing=True)
|
||||
|
||||
99
tests/policies/pi0_pi05/test_pi0_compile.py
Normal file
99
tests/policies/pi0_pi05/test_pi0_compile.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi0 import PI0Config # noqa: E402
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
|
||||
assert_cache_stability,
|
||||
assert_compiled_output_matches_eager,
|
||||
assert_explain_has_no_graph_breaks,
|
||||
benchmark_runtime,
|
||||
make_compile_config,
|
||||
reset_compile_state,
|
||||
)
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
|
||||
def _make_model(*, compile_model):
|
||||
return PI0Pytorch(make_compile_config(PI0Config, compile_model=compile_model)).cuda().eval()
|
||||
|
||||
|
||||
def _make_dummy_inputs(config):
|
||||
device = torch.device("cuda")
|
||||
common = {
|
||||
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
|
||||
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
|
||||
"lang_tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
|
||||
"lang_masks": torch.ones(1, 5, dtype=torch.bool, device=device),
|
||||
"state": torch.randn(1, config.max_state_dim, device=device),
|
||||
}
|
||||
forward_kwargs = {
|
||||
**common,
|
||||
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"time": torch.rand(1, device=device),
|
||||
}
|
||||
sample_kwargs = {
|
||||
**common,
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"num_steps": config.num_inference_steps,
|
||||
}
|
||||
return forward_kwargs, sample_kwargs
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_torch_compile_forward_and_sample_actions():
|
||||
if not hasattr(torch, "compile"):
|
||||
pytest.skip("torch.compile is not available")
|
||||
if not torch._dynamo.is_dynamo_supported():
|
||||
pytest.skip("torch._dynamo is not supported on this platform")
|
||||
|
||||
torch.manual_seed(0)
|
||||
eager_model = _make_model(compile_model=False)
|
||||
torch.manual_seed(0)
|
||||
compiled_model = _make_model(compile_model=True)
|
||||
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
|
||||
|
||||
try:
|
||||
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
|
||||
|
||||
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi0.forward")
|
||||
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi0.sample_actions")
|
||||
|
||||
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi0.forward")
|
||||
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions")
|
||||
|
||||
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi0.forward")
|
||||
benchmark_runtime(
|
||||
eager_model.sample_actions, compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions"
|
||||
)
|
||||
finally:
|
||||
reset_compile_state()
|
||||
del eager_model
|
||||
del compiled_model
|
||||
torch.cuda.empty_cache()
|
||||
@@ -14,51 +14,56 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation"""
|
||||
"""Compare LeRobot PI0 against the vendored OpenPI PyTorch reference."""
|
||||
|
||||
import gc
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip if openpi or transformers is not available
|
||||
pytest.importorskip("openpi")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
from lerobot.configs import PreTrainedConfig # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
deterministic_openpi_forward_preprocess,
|
||||
fix_reference_state_dict,
|
||||
fixed_flow_sampling,
|
||||
load_openpi_reference_state_dict,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
|
||||
from lerobot.types import PolicyAction # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05)
|
||||
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||
DUMMY_MAX_TOKEN_LEN = 48
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
COMPILE_MODE = "default"
|
||||
FORWARD_RTOL = 1e-4
|
||||
FORWARD_ATOL = 1e-4
|
||||
SAMPLE_RTOL = 1e-2
|
||||
SAMPLE_ATOL = 5e-3
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
"action": {
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
@@ -87,6 +92,15 @@ DUMMY_DATASET_STATS = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_cuda_after_test():
|
||||
yield
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
class PI0BaseOriginalConfig:
|
||||
action_dim: int = DUMMY_ACTION_DIM
|
||||
action_horizon: int = DUMMY_ACTION_HORIZON
|
||||
@@ -95,333 +109,156 @@ class PI0BaseOriginalConfig:
|
||||
precision: str = "float32"
|
||||
pi05: bool = False
|
||||
dtype: str = "float32"
|
||||
pytorch_compile_mode: str | None = None
|
||||
|
||||
|
||||
def instantiate_lerobot_pi0(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI0Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI0Policy.from_pretrained(pretrained_name_or_path="lerobot/pi0_base", strict=True)
|
||||
else:
|
||||
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI0Policy(config)
|
||||
def instantiate_lerobot_pi0(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
config = PreTrainedConfig.from_pretrained("lerobot/pi0_base")
|
||||
config.device = str(DEVICE)
|
||||
config.dtype = "float32"
|
||||
config.compile_model = compile_model
|
||||
config.compile_mode = COMPILE_MODE
|
||||
config.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
policy = PI0Policy.from_pretrained("lerobot/pi0_base", config=config, strict=True)
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
policy.config.device = str(DEVICE)
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
return policy, preprocessor
|
||||
|
||||
|
||||
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None):
|
||||
config = PI0BaseOriginalConfig()
|
||||
policy = PI0Pytorch(config)
|
||||
|
||||
if from_pretrained:
|
||||
try:
|
||||
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi0_base)...")
|
||||
|
||||
# Download the model from HuggingFace Hub
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the entire repository
|
||||
if model_path and os.path.exists(model_path):
|
||||
cache_dir = model_path
|
||||
print(f"Using cached model from: {cache_dir}")
|
||||
else:
|
||||
cache_dir = snapshot_download(repo_id="lerobot/pi0_base", repo_type="model")
|
||||
print(f"Downloaded model to: {cache_dir}")
|
||||
|
||||
# Try to load safetensors format first
|
||||
model_file = os.path.join(cache_dir, "model.safetensors")
|
||||
if os.path.exists(model_file):
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
print(f"Loaded {len(state_dict)} parameters from safetensors")
|
||||
else:
|
||||
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
|
||||
|
||||
# Load the state dict into the model
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys: {len(missing_keys)}")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in missing_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"Unexpected keys: {len(unexpected_keys)}")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in unexpected_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("All pretrained weights loaded successfully!")
|
||||
else:
|
||||
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load pretrained weights: {e}")
|
||||
print(" Using randomly initialized weights...")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
policy.to(DEVICE)
|
||||
def instantiate_original_pi0():
|
||||
policy = PI0Pytorch(PI0BaseOriginalConfig()).to(DEVICE)
|
||||
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi0_base"))
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
assert missing_keys == []
|
||||
assert unexpected_keys == []
|
||||
return policy
|
||||
|
||||
|
||||
def create_dummy_data():
|
||||
batch_size = 2 # Reduce batch size for testing
|
||||
device = DEVICE
|
||||
|
||||
# Use the exact same prompt for both implementations
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
"observation.images.left_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
"observation.images.right_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
|
||||
"""Extract the exact same processed inputs that LeRobot uses internally."""
|
||||
# Get the tokenized language from LeRobot's internal method
|
||||
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
|
||||
|
||||
# Get the preprocessed images from LeRobot's internal method
|
||||
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for original implementation
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
|
||||
def prepare_parity_inputs(lerobot_pi0, lerobot_preprocessor):
|
||||
torch.manual_seed(0)
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
)
|
||||
openpi_actions = openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
)
|
||||
assert_processor_inputs_match_lerobot(
|
||||
lerobot_pi0,
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=True,
|
||||
)
|
||||
batch_size = raw_batch[OBS_STATE].shape[0]
|
||||
noise = torch.randn(
|
||||
batch_size,
|
||||
DUMMY_ACTION_HORIZON,
|
||||
DUMMY_ACTION_DIM,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE,
|
||||
)
|
||||
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
|
||||
return lerobot_batch, openpi_observation, openpi_actions, noise, time
|
||||
|
||||
|
||||
class PI0Observation:
|
||||
"""Observation class that matches the original OpenPI format."""
|
||||
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(
|
||||
compile_model=compile_model,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
original_pi0 = instantiate_original_pi0()
|
||||
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
|
||||
lerobot_pi0,
|
||||
lerobot_preprocessor,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state,
|
||||
images,
|
||||
image_masks,
|
||||
tokenized_prompt,
|
||||
tokenized_prompt_mask,
|
||||
token_ar_mask,
|
||||
token_loss_mask,
|
||||
):
|
||||
self.state = state
|
||||
self.images = images
|
||||
self.image_masks = image_masks
|
||||
self.tokenized_prompt = tokenized_prompt
|
||||
self.tokenized_prompt_mask = tokenized_prompt_mask
|
||||
self.token_ar_mask = token_ar_mask
|
||||
self.token_loss_mask = token_loss_mask
|
||||
|
||||
|
||||
def create_original_observation_with_openpi_preprocessing(batch):
|
||||
"""Create observation object for OpenPI using OpenPI's own preprocessing."""
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
device = batch["observation.state"].device
|
||||
|
||||
# Create tokenizer for OpenPI (same as LeRobot uses)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
# Get task description
|
||||
if "task" in batch:
|
||||
tasks = batch["task"]
|
||||
if isinstance(tasks, str):
|
||||
# Single string: add newline if not present, then convert to list
|
||||
if not tasks.endswith("\n"):
|
||||
tasks = f"{tasks}\n"
|
||||
tasks = [tasks]
|
||||
elif isinstance(tasks, list) and all(isinstance(t, str) for t in tasks):
|
||||
# List of strings: add newline to each if not present
|
||||
tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks]
|
||||
if len(tasks) == 1:
|
||||
# Expand to batch size
|
||||
tasks = tasks * batch_size
|
||||
if len(tasks) != batch_size:
|
||||
raise ValueError(f"Expected batch size {batch_size}, got {len(tasks)}")
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
if gradient_checkpointing:
|
||||
lerobot_pi0.train()
|
||||
else:
|
||||
# Default task if not provided
|
||||
tasks = ["Pick up the object\n"] * batch_size
|
||||
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
tokenized = tokenizer(
|
||||
tasks,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=DUMMY_MAX_TOKEN_LEN,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
lang_tokens = tokenized["input_ids"].to(device)
|
||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for OpenPI
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
|
||||
image_dict = {
|
||||
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
|
||||
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
}
|
||||
|
||||
# Create image masks (all ones for real images)
|
||||
image_masks_dict = {}
|
||||
for key in image_dict:
|
||||
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||
|
||||
# Create raw observation object (before preprocessing)
|
||||
raw_observation = PI0Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
# Now use OpenPI's preprocessing
|
||||
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
|
||||
|
||||
return processed_obs
|
||||
|
||||
|
||||
def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
|
||||
_batch_size = batch["observation.state"].shape[0]
|
||||
_device = batch["observation.state"].device
|
||||
|
||||
# Extract the exact same processed inputs that LeRobot uses
|
||||
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
|
||||
extract_lerobot_processed_inputs(lerobot_pi0, batch)
|
||||
)
|
||||
|
||||
# Convert images list to dict with original OpenPI keys
|
||||
image_dict = {
|
||||
"base_0_rgb": images[0],
|
||||
"left_wrist_0_rgb": images[1],
|
||||
"right_wrist_0_rgb": images[2],
|
||||
}
|
||||
|
||||
# Convert image masks list to dict with original OpenPI keys
|
||||
image_masks_dict = {
|
||||
"base_0_rgb": img_masks[0],
|
||||
"left_wrist_0_rgb": img_masks[1],
|
||||
"right_wrist_0_rgb": img_masks[2],
|
||||
}
|
||||
|
||||
return PI0Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
|
||||
def test_pi0_original_vs_lerobot():
|
||||
"""Test PI0 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
lerobot_pi0, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi0(
|
||||
from_pretrained=True
|
||||
) # Load pretrained LeRobot model
|
||||
original_pi0 = instantiate_original_pi0(
|
||||
from_pretrained=True
|
||||
) # Load pretrained OpenPI model from HuggingFace Hub
|
||||
|
||||
print("Creating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
batch_lerobot = deepcopy(batch)
|
||||
|
||||
# Test each model with its own preprocessing (more realistic end-to-end test)
|
||||
print("\nTest each model with its own preprocessing")
|
||||
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||
|
||||
print(f"Task prompt: '{batch['task'][0]}'")
|
||||
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
|
||||
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
|
||||
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
|
||||
|
||||
print("Testing OpenPI with own preprocessing...")
|
||||
lerobot_pi0.eval()
|
||||
original_pi0.eval()
|
||||
torch.manual_seed(42) # Set seed for reproducibility
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
|
||||
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
|
||||
)
|
||||
openpi_actions_unit = openpi_actions[:, 0, :]
|
||||
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
|
||||
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
|
||||
with fixed_flow_sampling(lerobot_pi0.model, noise=noise, time=time):
|
||||
lerobot_loss, _ = lerobot_pi0(lerobot_batch, reduction="none")
|
||||
with deterministic_openpi_forward_preprocess(original_pi0):
|
||||
openpi_losses = original_pi0(openpi_observation, openpi_actions, noise=noise, time=time)
|
||||
openpi_loss = openpi_losses.mean(dim=(1, 2))
|
||||
|
||||
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
|
||||
|
||||
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
|
||||
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(compile_model=compile_model)
|
||||
original_pi0 = instantiate_original_pi0()
|
||||
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
|
||||
lerobot_pi0,
|
||||
lerobot_preprocessor,
|
||||
)
|
||||
|
||||
print("Testing LeRobot with own preprocessing...")
|
||||
lerobot_pi0.eval()
|
||||
torch.manual_seed(42) # Set the same seed
|
||||
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
original_pi0.eval()
|
||||
with torch.no_grad():
|
||||
lerobot_actions_own = lerobot_pi0.predict_action_chunk(
|
||||
batch_lerobot_processed
|
||||
) # batch_size, n_action_steps, action_dim
|
||||
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
|
||||
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
|
||||
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
|
||||
lerobot_actions = lerobot_pi0.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE,
|
||||
observation=openpi_observation,
|
||||
noise=noise,
|
||||
num_steps=10,
|
||||
)
|
||||
|
||||
print("\nComparing end-to-end implementations:")
|
||||
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
|
||||
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
||||
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
|
||||
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
|
||||
|
||||
def test_pi0_forward_matches_openpi():
|
||||
assert_forward_matches()
|
||||
|
||||
|
||||
def test_pi0_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi()
|
||||
|
||||
|
||||
def test_pi0_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(gradient_checkpointing=True)
|
||||
|
||||
|
||||
def test_pi0_compile_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True)
|
||||
|
||||
|
||||
def test_pi0_compile_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi(compile_model=True)
|
||||
|
||||
|
||||
def test_pi0_compile_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True, gradient_checkpointing=True)
|
||||
|
||||
1
tests/policies/pi0_pi05/utils/__init__.py
Normal file
1
tests/policies/pi0_pi05/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Utilities shared by PI0/PI05 policy tests."""
|
||||
291
tests/policies/pi0_pi05/utils/openpi_parity.py
Normal file
291
tests/policies/pi0_pi05/utils/openpi_parity.py
Normal file
@@ -0,0 +1,291 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as openpi_preprocessing
|
||||
|
||||
IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
|
||||
TOKENIZER_NAME = "google/paligemma-3b-pt-224"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenPIObservation:
|
||||
state: torch.Tensor
|
||||
images: dict[str, torch.Tensor]
|
||||
image_masks: dict[str, torch.Tensor]
|
||||
tokenized_prompt: torch.Tensor
|
||||
tokenized_prompt_mask: torch.Tensor
|
||||
token_ar_mask: torch.Tensor
|
||||
token_loss_mask: torch.Tensor
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def paligemma_tokenizer():
|
||||
return AutoTokenizer.from_pretrained(TOKENIZER_NAME)
|
||||
|
||||
|
||||
def clone_batch(batch: dict) -> dict:
|
||||
return {
|
||||
key: value.clone() if isinstance(value, torch.Tensor) else list(value) for key, value in batch.items()
|
||||
}
|
||||
|
||||
|
||||
def pad_last_dim(tensor: torch.Tensor, target_dim: int) -> torch.Tensor:
|
||||
if tensor.shape[-1] > target_dim:
|
||||
raise ValueError(f"Cannot pad last dimension {tensor.shape[-1]} down to {target_dim}")
|
||||
return F.pad(tensor, (0, target_dim - tensor.shape[-1]))
|
||||
|
||||
|
||||
def mean_std_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
mean = stats["mean"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
std = stats["std"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
return (tensor - mean) / (std + 1e-8)
|
||||
|
||||
|
||||
def quantile_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
q01 = stats["q01"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
q99 = stats["q99"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
denom = torch.where(q99 == q01, torch.full_like(q99, 1e-8), q99 - q01)
|
||||
return 2.0 * (tensor - q01) / denom - 1.0
|
||||
|
||||
|
||||
def openpi_model_state_from_raw(
|
||||
batch: dict[str, torch.Tensor],
|
||||
*,
|
||||
action_dim: int,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
||||
pi05: bool,
|
||||
) -> torch.Tensor:
|
||||
state = batch[OBS_STATE].to(dtype=torch.float32)
|
||||
if pi05:
|
||||
state = quantile_normalize(state, dataset_stats[OBS_STATE])
|
||||
else:
|
||||
state = mean_std_normalize(state, dataset_stats[OBS_STATE])
|
||||
return pad_last_dim(state, action_dim)
|
||||
|
||||
|
||||
def openpi_model_actions_from_raw(
|
||||
batch: dict[str, torch.Tensor],
|
||||
*,
|
||||
action_dim: int,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
||||
pi05: bool,
|
||||
) -> torch.Tensor:
|
||||
actions = batch[ACTION].to(dtype=torch.float32)
|
||||
if pi05:
|
||||
actions = quantile_normalize(actions, dataset_stats[ACTION])
|
||||
else:
|
||||
actions = mean_std_normalize(actions, dataset_stats[ACTION])
|
||||
return pad_last_dim(actions, action_dim)
|
||||
|
||||
|
||||
def _tasks_from_raw(batch: dict, batch_size: int) -> list[str]:
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
raise ValueError("The parity batch must include a task prompt.")
|
||||
if isinstance(tasks, str):
|
||||
return [tasks] * batch_size
|
||||
if len(tasks) == 1:
|
||||
return [tasks[0]] * batch_size
|
||||
if len(tasks) != batch_size:
|
||||
raise ValueError(f"Expected {batch_size} task prompts, got {len(tasks)}")
|
||||
return list(tasks)
|
||||
|
||||
|
||||
def _format_pi0_prompts(tasks: list[str]) -> list[str]:
|
||||
return [f"{task.strip().replace('_', ' ').replace(chr(10), ' ')}\n" for task in tasks]
|
||||
|
||||
|
||||
def _format_pi05_prompts(tasks: list[str], normalized_state: torch.Tensor) -> list[str]:
|
||||
state_np = normalized_state.detach().cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
prompts = []
|
||||
for task, state in zip(tasks, discretized_states, strict=True):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, state))
|
||||
prompts.append(f"Task: {cleaned_text}, State: {state_str};\nAction: ")
|
||||
return prompts
|
||||
|
||||
|
||||
def _tokenize_prompts(prompts: list[str], *, max_token_len: int, device: torch.device | str):
|
||||
tokenized = paligemma_tokenizer()(
|
||||
prompts,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=max_token_len,
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = tokenized["input_ids"].to(device)
|
||||
masks = tokenized["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
return tokens, masks
|
||||
|
||||
|
||||
def make_openpi_observation_from_raw(
|
||||
batch: dict[str, torch.Tensor],
|
||||
*,
|
||||
action_dim: int,
|
||||
max_token_len: int,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
||||
pi05: bool,
|
||||
) -> OpenPIObservation:
|
||||
batch_size = batch[OBS_STATE].shape[0]
|
||||
device = batch[OBS_STATE].device
|
||||
state = openpi_model_state_from_raw(
|
||||
batch,
|
||||
action_dim=action_dim,
|
||||
dataset_stats=dataset_stats,
|
||||
pi05=pi05,
|
||||
)
|
||||
|
||||
tasks = _tasks_from_raw(batch, batch_size)
|
||||
prompts = _format_pi05_prompts(tasks, state) if pi05 else _format_pi0_prompts(tasks)
|
||||
tokens, masks = _tokenize_prompts(prompts, max_token_len=max_token_len, device=device)
|
||||
|
||||
images = {
|
||||
key: batch[f"observation.images.{key}"].to(device=device, dtype=torch.float32) * 2.0 - 1.0
|
||||
for key in IMAGE_KEYS
|
||||
}
|
||||
image_masks = {key: torch.ones(batch_size, dtype=torch.bool, device=device) for key in IMAGE_KEYS}
|
||||
|
||||
return OpenPIObservation(
|
||||
state=state,
|
||||
images=images,
|
||||
image_masks=image_masks,
|
||||
tokenized_prompt=tokens,
|
||||
tokenized_prompt_mask=masks,
|
||||
token_ar_mask=torch.zeros_like(tokens, dtype=torch.int32),
|
||||
token_loss_mask=torch.ones_like(masks, dtype=torch.bool),
|
||||
)
|
||||
|
||||
|
||||
def assert_processor_inputs_match_lerobot(
|
||||
lerobot_policy,
|
||||
lerobot_batch: dict[str, torch.Tensor],
|
||||
openpi_observation: OpenPIObservation,
|
||||
*,
|
||||
compare_state: bool,
|
||||
):
|
||||
openpi_processed = openpi_preprocessing.preprocess_observation_pytorch(openpi_observation, train=False)
|
||||
lerobot_images, lerobot_image_masks = lerobot_policy._preprocess_images(lerobot_batch)
|
||||
|
||||
# Token IDs, token masks, images, image masks, and PI0 state are intentionally built from the same
|
||||
# raw batch through independent LeRobot/OpenPI-style processor logic. They must be bitwise equal.
|
||||
torch.testing.assert_close(
|
||||
openpi_observation.tokenized_prompt, lerobot_batch[OBS_LANGUAGE_TOKENS], rtol=0, atol=0
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
openpi_observation.tokenized_prompt_mask,
|
||||
lerobot_batch[OBS_LANGUAGE_ATTENTION_MASK],
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
|
||||
for openpi_image, lerobot_image in zip(openpi_processed.images.values(), lerobot_images, strict=True):
|
||||
torch.testing.assert_close(openpi_image, lerobot_image, rtol=0, atol=0)
|
||||
|
||||
for openpi_mask, lerobot_mask in zip(
|
||||
openpi_processed.image_masks.values(), lerobot_image_masks, strict=True
|
||||
):
|
||||
torch.testing.assert_close(openpi_mask, lerobot_mask, rtol=0, atol=0)
|
||||
|
||||
if compare_state:
|
||||
torch.testing.assert_close(
|
||||
openpi_processed.state, lerobot_policy.prepare_state(lerobot_batch), rtol=0, atol=0
|
||||
)
|
||||
|
||||
|
||||
def load_openpi_reference_state_dict(repo_id: str) -> dict[str, torch.Tensor]:
|
||||
cache_dir = Path(snapshot_download(repo_id=repo_id, repo_type="model"))
|
||||
return safetensors.torch.load_file(cache_dir / "model.safetensors")
|
||||
|
||||
|
||||
def fix_reference_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
fixed_state_dict = dict(state_dict)
|
||||
lm_head_key = "paligemma_with_expert.paligemma.lm_head.weight"
|
||||
embed_tokens_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
if lm_head_key in fixed_state_dict and embed_tokens_key not in fixed_state_dict:
|
||||
fixed_state_dict[embed_tokens_key] = fixed_state_dict[lm_head_key].clone()
|
||||
return fixed_state_dict
|
||||
|
||||
|
||||
@contextmanager
|
||||
def fixed_flow_sampling(model, *, noise: torch.Tensor, time: torch.Tensor) -> Iterator[None]:
|
||||
original_sample_noise = model.sample_noise
|
||||
original_sample_time = model.sample_time
|
||||
|
||||
def sample_noise(shape, device):
|
||||
if tuple(shape) != tuple(noise.shape):
|
||||
raise ValueError(f"Expected noise shape {tuple(noise.shape)}, got {tuple(shape)}")
|
||||
return noise.to(device=device)
|
||||
|
||||
def sample_time(batch_size, device):
|
||||
if batch_size != time.shape[0]:
|
||||
raise ValueError(f"Expected time batch size {time.shape[0]}, got {batch_size}")
|
||||
return time.to(device=device)
|
||||
|
||||
model.sample_noise = sample_noise
|
||||
model.sample_time = sample_time
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
model.sample_noise = original_sample_noise
|
||||
model.sample_time = original_sample_time
|
||||
|
||||
|
||||
@contextmanager
|
||||
def deterministic_openpi_forward_preprocess(openpi_policy) -> Iterator[None]:
|
||||
"""Disable OpenPI's training-time image augmentation only inside a parity forward block.
|
||||
|
||||
OpenPI's `forward()` calls `_preprocess_observation(..., train=True)`, which can apply stochastic
|
||||
image augmentation. LeRobot's policy forward path does not apply that augmentation, so parity would
|
||||
otherwise compare two different image tensors rather than two model implementations. The context manager
|
||||
keeps the public `openpi_policy.forward(observation, ...)` call while making preprocessing deterministic.
|
||||
|
||||
`yield` marks the body of the caller's `with` block. The `try/finally` restores the original method even
|
||||
if the assertion inside the block fails, so the temporary monkeypatch cannot leak into later tests.
|
||||
"""
|
||||
|
||||
original_preprocess_observation = openpi_policy._preprocess_observation
|
||||
|
||||
def preprocess_observation(observation, *, train=True):
|
||||
return original_preprocess_observation(observation, train=False)
|
||||
|
||||
openpi_policy._preprocess_observation = preprocess_observation
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
openpi_policy._preprocess_observation = original_preprocess_observation
|
||||
207
tests/policies/pi0_pi05/utils/torch_compile.py
Normal file
207
tests/policies/pi0_pi05/utils/torch_compile.py
Normal file
@@ -0,0 +1,207 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters, guard_failures
|
||||
from torch.profiler import ProfilerActivity
|
||||
|
||||
FORWARD_RTOL = 1e-5
|
||||
FORWARD_ATOL = 5e-2
|
||||
SAMPLE_RTOL = 1e-5
|
||||
SAMPLE_ATOL = 1e-2
|
||||
COMPILE_MODE = "max-autotune"
|
||||
STEADY_STATE_WARMUPS = 3
|
||||
STEADY_STATE_REPEATS = 3
|
||||
|
||||
|
||||
def make_compile_config(config_cls, *, compile_model):
|
||||
return config_cls(device="cuda", compile_model=compile_model, compile_mode=COMPILE_MODE)
|
||||
|
||||
|
||||
def counter_total(name):
|
||||
return sum(counters.get(name, {}).values())
|
||||
|
||||
|
||||
def compile_snapshot():
|
||||
return {
|
||||
"graph_breaks": counter_total("graph_break"),
|
||||
"recompiles": counter_total("recompiles"),
|
||||
"recompile_limits": counter_total("recompile_limit"),
|
||||
"unique_graphs": counters["stats"].get("unique_graphs", 0),
|
||||
}
|
||||
|
||||
|
||||
def reset_compile_state():
|
||||
torch._dynamo.reset()
|
||||
counters.clear()
|
||||
guard_failures.clear()
|
||||
|
||||
|
||||
def clone_cuda_graph_output(output):
|
||||
if torch.is_tensor(output):
|
||||
return output.clone()
|
||||
if isinstance(output, tuple):
|
||||
return tuple(clone_cuda_graph_output(item) for item in output)
|
||||
if isinstance(output, list):
|
||||
return [clone_cuda_graph_output(item) for item in output]
|
||||
if isinstance(output, dict):
|
||||
return {key: clone_cuda_graph_output(value) for key, value in output.items()}
|
||||
return output
|
||||
|
||||
|
||||
def run_model_step(fn: Callable, kwargs: dict):
|
||||
if hasattr(torch.compiler, "cudagraph_mark_step_begin"):
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
return fn(**kwargs)
|
||||
|
||||
|
||||
def assert_explain_has_no_graph_breaks(fn: Callable, kwargs: dict, label: str):
|
||||
reset_compile_state()
|
||||
explanation = torch._dynamo.explain(fn)(**kwargs)
|
||||
|
||||
assert explanation.graph_count > 0, f"{label} was not captured by Dynamo"
|
||||
assert explanation.graph_break_count == 0, (
|
||||
f"{label} has {explanation.graph_break_count} graph break(s): {explanation.break_reasons}"
|
||||
)
|
||||
assert not explanation.break_reasons, f"{label} graph break reasons: {explanation.break_reasons}"
|
||||
|
||||
print(
|
||||
f"{label} capture: graphs={explanation.graph_count}, "
|
||||
f"graph_breaks={explanation.graph_break_count}, ops={explanation.op_count}, "
|
||||
f"guards={len(explanation.out_guards or [])}"
|
||||
)
|
||||
return explanation
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs):
|
||||
eager_forward = eager_model.forward(**forward_kwargs)
|
||||
compiled_forward = compiled_model.forward(**forward_kwargs)
|
||||
torch.testing.assert_close(compiled_forward, eager_forward, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
|
||||
eager_actions = eager_model.sample_actions(**sample_kwargs)
|
||||
compiled_actions = compiled_model.sample_actions(**sample_kwargs)
|
||||
torch.testing.assert_close(compiled_actions, eager_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def assert_cache_stability(fn: Callable, kwargs: dict, label: str):
|
||||
reset_compile_state()
|
||||
|
||||
first_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
||||
first_snapshot = compile_snapshot()
|
||||
second_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
||||
second_snapshot = compile_snapshot()
|
||||
third_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
||||
third_snapshot = compile_snapshot()
|
||||
|
||||
torch.testing.assert_close(second_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
torch.testing.assert_close(third_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
assert first_snapshot["unique_graphs"] > 0, f"{label} did not compile any graph"
|
||||
assert third_snapshot["graph_breaks"] == 0, f"{label} graph breaks: {third_snapshot}"
|
||||
assert third_snapshot["recompiles"] == 0, f"{label} recompiled: {third_snapshot}"
|
||||
assert third_snapshot["recompile_limits"] == 0, f"{label} hit recompile limit: {third_snapshot}"
|
||||
assert second_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
|
||||
f"{label} compiled new graph on second call: first={first_snapshot}, second={second_snapshot}"
|
||||
)
|
||||
assert third_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
|
||||
f"{label} compiled new graph on third call: first={first_snapshot}, third={third_snapshot}"
|
||||
)
|
||||
assert not guard_failures, f"{label} guard failures: {dict(guard_failures)}"
|
||||
|
||||
print(f"{label} cache: first={first_snapshot}, third={third_snapshot}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def benchmark_runtime(eager_fn: Callable, compiled_fn: Callable, kwargs: dict, label: str):
|
||||
run_warmups(eager_fn, kwargs)
|
||||
run_warmups(compiled_fn, kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
eager_metrics = profile_callable(eager_fn, kwargs)
|
||||
compiled_metrics = profile_callable(compiled_fn, kwargs)
|
||||
speedup = eager_metrics["cuda_event_ms"] / compiled_metrics["cuda_event_ms"]
|
||||
|
||||
print(
|
||||
f"{label} runtime: eager_cuda={eager_metrics['cuda_event_ms']:.3f} ms, "
|
||||
f"compiled_cuda={compiled_metrics['cuda_event_ms']:.3f} ms, speedup={speedup:.3f}x, "
|
||||
f"host_wall_ms eager/compiled={eager_metrics['host_wall_ms']:.3f}/"
|
||||
f"{compiled_metrics['host_wall_ms']:.3f}, "
|
||||
f"cpu_self_time_ms eager/compiled={eager_metrics['cpu_self_time_ms']:.3f}/"
|
||||
f"{compiled_metrics['cpu_self_time_ms']:.3f}, "
|
||||
f"cuda_launches eager/compiled={eager_metrics['cuda_launch_count']}/"
|
||||
f"{compiled_metrics['cuda_launch_count']}, "
|
||||
f"profiler_events eager/compiled={eager_metrics['profiler_event_count']}/"
|
||||
f"{compiled_metrics['profiler_event_count']}, "
|
||||
f"peak_mem_mib eager/compiled={eager_metrics['peak_mem_mib']:.1f}/"
|
||||
f"{compiled_metrics['peak_mem_mib']:.1f}"
|
||||
)
|
||||
|
||||
assert eager_metrics["cuda_event_ms"] > 0
|
||||
assert compiled_metrics["cuda_event_ms"] > 0
|
||||
assert eager_metrics["profiler_event_count"] > 0
|
||||
assert compiled_metrics["profiler_event_count"] > 0
|
||||
return eager_metrics, compiled_metrics
|
||||
|
||||
|
||||
def run_warmups(fn: Callable, kwargs: dict):
|
||||
for _ in range(STEADY_STATE_WARMUPS):
|
||||
run_model_step(fn, kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def profile_callable(fn: Callable, kwargs: dict):
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
host_start = time.perf_counter()
|
||||
start_event.record()
|
||||
for _ in range(STEADY_STATE_REPEATS):
|
||||
run_model_step(fn, kwargs)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
cuda_event_ms = start_event.elapsed_time(end_event) / STEADY_STATE_REPEATS
|
||||
host_wall_ms = (time.perf_counter() - host_start) * 1000 / STEADY_STATE_REPEATS
|
||||
peak_mem_mib = torch.cuda.max_memory_allocated() / 1024**2
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[ProfilerActivity.CPU],
|
||||
) as profiler:
|
||||
run_model_step(fn, kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
key_averages = profiler.key_averages()
|
||||
cpu_self_time_ms = sum(event.self_cpu_time_total for event in key_averages) / 1000
|
||||
cuda_launch_count = sum(
|
||||
event.count
|
||||
for event in key_averages
|
||||
if event.key in {"cudaLaunchKernel", "cudaGraphLaunch", "cudaLaunchKernelExC"}
|
||||
)
|
||||
profiler_event_count = sum(event.count for event in key_averages)
|
||||
|
||||
return {
|
||||
"cuda_event_ms": cuda_event_ms,
|
||||
"host_wall_ms": host_wall_ms,
|
||||
"cpu_self_time_ms": cpu_self_time_ms,
|
||||
"cuda_launch_count": cuda_launch_count,
|
||||
"profiler_event_count": profiler_event_count,
|
||||
"peak_mem_mib": peak_mem_mib,
|
||||
}
|
||||
155
tests/processor/test_pi05_processor.py
Normal file
155
tests/processor/test_pi05_processor.py
Normal file
@@ -0,0 +1,155 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compare the PI0.5 processor pipeline against the vendored OpenPI reference processors."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
|
||||
from lerobot.policies.pi05 import PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
IMAGE_KEYS,
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
|
||||
)
|
||||
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 200
|
||||
DEVICE = torch.device("cpu")
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"q99": torch.ones(DUMMY_ACTION_DIM),
|
||||
},
|
||||
"images": {
|
||||
key: {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
}
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PI05PolicyInputAdapter(torch.nn.Module):
|
||||
"""Minimal adapter exposing PI0.5 policy image preparation without loading model weights."""
|
||||
|
||||
_preprocess_images = PI05Policy._preprocess_images
|
||||
|
||||
def __init__(self, config: PI05Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
|
||||
|
||||
|
||||
def create_pi05_config() -> PI05Config:
|
||||
config = PI05Config(device=str(DEVICE))
|
||||
config.max_state_dim = DUMMY_STATE_DIM
|
||||
config.max_action_dim = DUMMY_ACTION_DIM
|
||||
config.chunk_size = DUMMY_ACTION_HORIZON
|
||||
config.n_action_steps = DUMMY_ACTION_HORIZON
|
||||
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
|
||||
**{
|
||||
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def create_dummy_data() -> dict:
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
**{
|
||||
f"observation.images.{key}": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
)
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
|
||||
|
||||
def test_pi05_processor_inputs_match_openpi_reference():
|
||||
torch.manual_seed(0)
|
||||
config = create_pi05_config()
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
)
|
||||
|
||||
assert_processor_inputs_match_lerobot(
|
||||
PI05PolicyInputAdapter(config),
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=False,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
lerobot_batch[ACTION],
|
||||
openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
),
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
156
tests/processor/test_pi0_processor.py
Normal file
156
tests/processor/test_pi0_processor.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compare the PI0 processor pipeline against the vendored OpenPI reference processors."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
IMAGE_KEYS,
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
|
||||
)
|
||||
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 48
|
||||
DEVICE = torch.device("cpu")
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"q99": torch.ones(DUMMY_ACTION_DIM),
|
||||
},
|
||||
"images": {
|
||||
key: {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
}
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PI0PolicyInputAdapter(torch.nn.Module):
|
||||
"""Minimal adapter exposing PI0 policy input-preparation helpers without loading model weights."""
|
||||
|
||||
_preprocess_images = PI0Policy._preprocess_images
|
||||
prepare_state = PI0Policy.prepare_state
|
||||
|
||||
def __init__(self, config: PI0Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
|
||||
|
||||
|
||||
def create_pi0_config() -> PI0Config:
|
||||
config = PI0Config(device=str(DEVICE))
|
||||
config.max_state_dim = DUMMY_STATE_DIM
|
||||
config.max_action_dim = DUMMY_ACTION_DIM
|
||||
config.chunk_size = DUMMY_ACTION_HORIZON
|
||||
config.n_action_steps = DUMMY_ACTION_HORIZON
|
||||
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
|
||||
**{
|
||||
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def create_dummy_data() -> dict:
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
**{
|
||||
f"observation.images.{key}": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
)
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
|
||||
|
||||
def test_pi0_processor_inputs_match_openpi_reference():
|
||||
torch.manual_seed(0)
|
||||
config = create_pi0_config()
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
)
|
||||
|
||||
assert_processor_inputs_match_lerobot(
|
||||
PI0PolicyInputAdapter(config),
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=True,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
lerobot_batch[ACTION],
|
||||
openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
),
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
@@ -1,10 +1,14 @@
|
||||
"""Tests for policy.path support in YAML config files (issue #2957)."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from unittest.mock import patch
|
||||
|
||||
import yaml
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.parser import (
|
||||
_config_path_args,
|
||||
_config_yaml_overrides,
|
||||
@@ -16,7 +20,8 @@ from lerobot.configs.parser import (
|
||||
|
||||
|
||||
def test_extract_path_fields_from_yaml():
|
||||
"""Test that policy.path is extracted from a YAML config and removed."""
|
||||
"""Test that policy.path is extracted from a YAML config and the policy block
|
||||
is removed entirely (siblings are captured separately as cli_overrides)."""
|
||||
config = {
|
||||
"dataset": {"repo_id": "lerobot/pusht"},
|
||||
"policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False},
|
||||
@@ -26,26 +31,33 @@ def test_extract_path_fields_from_yaml():
|
||||
config_path = f.name
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
||||
|
||||
# Path should be extracted and stored
|
||||
assert _config_path_args["policy"] == "lerobot/smolvla_base"
|
||||
|
||||
# Cleaned config should not have the path field
|
||||
# Cleaned config should not have the policy block at all -- draccus must not
|
||||
# try to decode it as PreTrainedConfig; the actual config comes from
|
||||
# from_pretrained(path) with the captured overrides applied on top.
|
||||
with open(cleaned_path) as f:
|
||||
cleaned = yaml.safe_load(f)
|
||||
assert "path" not in cleaned["policy"]
|
||||
assert cleaned["policy"]["type"] == "smolvla"
|
||||
assert cleaned["policy"]["push_to_hub"] is False
|
||||
assert "policy" not in cleaned
|
||||
|
||||
# Original dataset should be untouched
|
||||
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
|
||||
|
||||
# Sibling overrides (excluding type/path) captured for from_pretrained.
|
||||
overrides = get_yaml_overrides("policy")
|
||||
assert any("push_to_hub=false" in o for o in overrides)
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
|
||||
def test_extract_path_fields_from_json():
|
||||
"""Test that policy.path is extracted from a JSON config."""
|
||||
"""Test that policy.path is extracted from a JSON config and the policy
|
||||
block is removed entirely."""
|
||||
config = {
|
||||
"policy": {"type": "act", "path": "some/local/path"},
|
||||
}
|
||||
@@ -54,15 +66,17 @@ def test_extract_path_fields_from_json():
|
||||
config_path = f.name
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
||||
|
||||
assert _config_path_args["policy"] == "some/local/path"
|
||||
|
||||
with open(cleaned_path) as f:
|
||||
cleaned = json.load(f)
|
||||
assert "path" not in cleaned["policy"]
|
||||
assert "policy" not in cleaned
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
|
||||
def test_extract_no_path_returns_original():
|
||||
@@ -216,3 +230,91 @@ def test_flatten_nested_with_bools():
|
||||
args = _flatten_to_cli_args(d)
|
||||
assert "--optimizer.use_warmup=true" in args
|
||||
assert "--optimizer.lr=0.01" in args
|
||||
|
||||
|
||||
def test_extract_removes_field_with_siblings_and_no_type():
|
||||
"""Regression: when policy.path has siblings but no type:, the entire policy
|
||||
block must still be removed from the cleaned config. Otherwise draccus tries
|
||||
to decode the leftover dict as PreTrainedConfig and crashes on the missing
|
||||
type discriminator.
|
||||
"""
|
||||
config = {
|
||||
"dataset": {"repo_id": "lerobot/pusht"},
|
||||
"policy": {
|
||||
"path": "lerobot/smolvla_base",
|
||||
"n_action_steps": 10,
|
||||
"dtype": "bfloat16",
|
||||
},
|
||||
}
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
yaml.dump(config, f)
|
||||
config_path = f.name
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
||||
|
||||
with open(cleaned_path) as f:
|
||||
cleaned = yaml.safe_load(f) or {}
|
||||
assert "policy" not in cleaned, "policy block should be fully removed when path is present"
|
||||
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
|
||||
assert _config_path_args["policy"] == "lerobot/smolvla_base"
|
||||
overrides = get_yaml_overrides("policy")
|
||||
assert any("n_action_steps=10" in o for o in overrides)
|
||||
assert any("dtype=bfloat16" in o for o in overrides)
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DummyNested:
|
||||
foo: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DummyConfig:
|
||||
nested: _DummyNested = field(default_factory=_DummyNested)
|
||||
other: str = "default"
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls):
|
||||
return ["nested"]
|
||||
|
||||
|
||||
def test_wrap_uses_cleaned_config_for_draccus_parse():
|
||||
"""Regression: wrap() updates config_path_cli to point at the cleaned temp
|
||||
file but must propagate that to the draccus.parse fallback branch. Without
|
||||
the fix, cli_args still contains --config_path=<original> and draccus reads
|
||||
the original YAML with `path:` still in it, crashing on the unknown field.
|
||||
"""
|
||||
config = {
|
||||
"nested": {"path": "some/checkpoint", "foo": 42},
|
||||
"other": "set-via-yaml",
|
||||
}
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
yaml.dump(config, f)
|
||||
config_path = f.name
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: _DummyConfig) -> _DummyConfig:
|
||||
captured["cfg"] = cfg
|
||||
return cfg
|
||||
|
||||
with patch.object(sys, "argv", ["prog", f"--config_path={config_path}"]):
|
||||
main()
|
||||
|
||||
assert captured["cfg"].other == "set-via-yaml"
|
||||
assert _config_path_args["nested"] == "some/checkpoint"
|
||||
# Cleaned config dropped `nested:` entirely; defaults stand for this wrapper
|
||||
# class (a real PreTrainedConfig would now load the checkpoint and apply
|
||||
# the captured yaml_overrides via from_pretrained()).
|
||||
assert captured["cfg"].nested.foo == 0
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
22
uv.lock
generated
22
uv.lock
generated
@@ -3203,7 +3203,7 @@ requires-dist = [
|
||||
{ name = "pandas", marker = "extra == 'video-benchmark'", specifier = ">=2.2.2,<2.4.0" },
|
||||
{ name = "peft", marker = "extra == 'peft-dep'", specifier = ">=0.18.0,<1.0.0" },
|
||||
{ name = "pillow", specifier = ">=10.0.0,<13.0.0" },
|
||||
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.17" },
|
||||
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.16" },
|
||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" },
|
||||
{ name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" },
|
||||
{ name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" },
|
||||
@@ -4592,7 +4592,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "placo"
|
||||
version = "0.9.16"
|
||||
version = "0.9.15"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cmeel" },
|
||||
@@ -4602,16 +4602,16 @@ dependencies = [
|
||||
{ name = "pin" },
|
||||
{ name = "rhoban-cmeel-jsoncpp" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9e/0a/36c5b729d0d69075e7dfafd1b36c4df6fbb8c1ff1585e88d3c56d4c15010/placo-0.9.16.tar.gz", hash = "sha256:5314faaf6442e7ffe17347680d236af953951813bbfb1c09c4a27f7388d332e4", size = 136871, upload-time = "2025-11-07T14:24:58.811Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/40/c4/a33a0ee2ad798471a1c43a96109d28f358fd95c78a56f8cff57acb66d2bc/placo-0.9.15.tar.gz", hash = "sha256:df47f1154bae305c943bd20ba4f56d50ffc65625efc98679fefb11e8ff3c462c", size = 136856, upload-time = "2025-11-03T10:49:13.151Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/95/8a85b58033303fd354a680e1494f47801abdca9133c222ae1c2473983f25/placo-0.9.16-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:417a89920b340e3aec19f1f49e1fb06789c679a807450157af8bdf4aef4bc82b", size = 1641806, upload-time = "2025-11-07T14:24:34.736Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/bd/2fb3556c71b0689b3168c0e85fce5befb605affcfe4afb3b5e7b5ba6749f/placo-0.9.16-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a7ef7ac33ba889d2122db0d7ed55eeecdffed020e2282712989bb11e408bab40", size = 1515468, upload-time = "2025-11-07T14:24:36.587Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ea/fd/7dba380720dfb89df582a51d0b2cb43957a36849f676baa3dfc74704e67f/placo-0.9.16-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:885773fe8a8e809022451ec16d47479562a042596f663b8c5bbe762cd616f573", size = 2106540, upload-time = "2025-11-07T14:24:38.149Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7a/40/97c7c799fe4f89111b973d7a5f86626a2ec1d0e6e20ce2988e0a2bda66f5/placo-0.9.16-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:19f097305c714e539fbf19e761897f6daab2ff73f639319431b144e77dd3852e", size = 2178511, upload-time = "2025-11-07T14:24:40.04Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/4d/f1700aae269584477b5d72561d2fc5ace37b4bca167892a74a369849c67e/placo-0.9.16-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:be11fa987702114097ccf3d94e1c4a891796878429e25c8d88b187ecc652e7ae", size = 1641812, upload-time = "2025-11-07T14:24:41.308Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/d7/21d1d0dd1311c0cbd9ccd233cdae520bbe2370095e3c831059d6077c90bd/placo-0.9.16-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c2d65aeb4844eae28006ad3a50c8519b27c701912cc99c46c95e33ed049f3635", size = 1515457, upload-time = "2025-11-07T14:24:42.758Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/e8/939ba23bfa539fb90ab9ab1c2c59ff9a9a46e24699fc90e8ca3ff2948646/placo-0.9.16-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a7633aff1c592c1f45e86a174a372d5d7972673935cb9151391277ff49ec2072", size = 2106538, upload-time = "2025-11-07T14:24:44.517Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/08/00/ad24cc0ad85fbe12267df28c2061e1eaef8f852146c467fcd7a681e11028/placo-0.9.16-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0d97a7284b65fc45aef27865c80cf7e53f04646d35bb18494ab62dfbbc9a35bd", size = 2178514, upload-time = "2025-11-07T14:24:45.994Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ef/03/207b1c087996b918fdbaa5a3a685e3b14b068cd303bf87affdf83f722b33/placo-0.9.15-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:eab7a299e73291fe631c02448b9e9826539f4824e198bcf85f7c91fdd77d054b", size = 1641975, upload-time = "2025-11-03T10:48:48.887Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/55/40432b26bb1c5b9e677fbc41e8d85b54fa8897b7daebb2a22d410b0a7f7b/placo-0.9.15-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:23f9dd19b8d15fa9d86968948b57981ebc6f1decafeffc2d646d8b56f685b50d", size = 1515448, upload-time = "2025-11-03T10:48:50.562Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/8e/e6283201d329409dccf2045b5c1efd73b3dad5268143bbea4668029ca9c6/placo-0.9.15-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:2680a2166c23a0a2aa6226ad75c63a2b2310c812673a5db296616d9af053e076", size = 2106550, upload-time = "2025-11-03T10:48:52.364Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/c3/77efe4c999e1d80ec14879ef73ea2a2144aa12db2b67870a562f87ed5b43/placo-0.9.15-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:1a2202a78bcd2874ca09a9a6526a95b38874803923cb9b3b4b96cd68ab4b7217", size = 2178531, upload-time = "2025-11-03T10:48:53.932Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fe/e7/b5cc5ad53414ff7af3357e0c9d97d902a3ce276e7810f8814fe9f0c1fb70/placo-0.9.15-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:84a445a99b059a512d1b4c64841a91d6f50149c7be9255c65bedeebbe6663989", size = 1641982, upload-time = "2025-11-03T10:48:55.277Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/1c/1c9163d941698a077617f218041efc573d3bf5a1c169a284112bd622fccd/placo-0.9.15-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b3106e7e6b05cbfa494239d8aa14795f7da8ee5dec851602f0d6297e311d7334", size = 1515447, upload-time = "2025-11-03T10:48:56.975Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/22/3d9b9045b89248c8476dd42243bc9821a123d9199e4e96a944124ad80cf1/placo-0.9.15-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:66c3d099e87551401aace04f1293a3c3563b1399319976647846845bf92c3ccf", size = 2106558, upload-time = "2025-11-03T10:48:58.667Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/20/0b/45dbdd2c378a7cece578b7344fda493d5a2aa6777089798a315ce4f97c22/placo-0.9.15-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0e06b7d3d618ddc2b649ab8b0b46db8001fe72fe2fbcc801524df0ccc8a3da40", size = 2178531, upload-time = "2025-11-03T10:49:00.533Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user