mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
Compare commits
8 Commits
docs/model
...
5c98e80430
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c98e80430 | ||
|
|
f65f3f7a4a | ||
|
|
8194897994 | ||
|
|
9f437d86b6 | ||
|
|
b74a551d38 | ||
|
|
c0a2e9814d | ||
|
|
bac4f61eae | ||
|
|
f4b834844e |
@@ -79,17 +79,13 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
|
||||
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/act_policy \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_robot \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Your task description" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
--task="Your task description" \ # can be skipped for ACT
|
||||
--duration=60
|
||||
```
|
||||
|
||||
@@ -105,10 +105,12 @@ These results demonstrate GR00T's strong generalization capabilities across dive
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout\
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--robot.type=bi_so_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
@@ -119,14 +121,12 @@ lerobot-record \
|
||||
}' \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10
|
||||
--duration=600
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760431541",
|
||||
id="my_red_robot_arm",
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
)
|
||||
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=my_awesome_leader_arm \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem5AB90687491 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem5AB90689011 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
</hfoption>
|
||||
@@ -122,34 +122,48 @@ lerobot-teleoperate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
import time
|
||||
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
|
||||
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
|
||||
|
||||
camera_config = {
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
|
||||
}
|
||||
|
||||
robot_config = KochFollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
id="my_red_robot_arm",
|
||||
cameras=camera_config
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
teleop_config = KochLeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = KochFollower(robot_config)
|
||||
teleop_device = KochLeader(teleop_config)
|
||||
init_rerun(session_name="teleoperation")
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop_device = SO101Leader(teleop_config)
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
TARGET_HZ = 30
|
||||
TIME_PER_FRAME = 1.0 / TARGET_HZ
|
||||
|
||||
while True:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
observation = robot.get_observation()
|
||||
action = teleop_device.get_action()
|
||||
robot.send_action(action)
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
|
||||
elapsed_time = time.perf_counter() - start_time
|
||||
sleep_time = TIME_PER_FRAME - elapsed_time
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -202,10 +216,11 @@ lerobot-record \
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
|
||||
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
@@ -218,71 +233,56 @@ EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
|
||||
# Create robot configuration
|
||||
robot_config = SO100FollowerConfig(
|
||||
id="my_awesome_follower_arm",
|
||||
cameras={
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
||||
},
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
)
|
||||
|
||||
teleop_config = SO100LeaderConfig(
|
||||
id="my_awesome_leader_arm",
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop = SO100Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
def main():
|
||||
# Create robot configuration
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop = SO101Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
@@ -291,26 +291,50 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# finalize dataset
|
||||
log_say("Finalizing dataset...")
|
||||
dataset.finalize()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -348,7 +372,7 @@ The `record` function provides a suite of tools for capturing and managing data
|
||||
##### 2. Checkpointing and Resuming
|
||||
|
||||
- Checkpoints are automatically created during recording.
|
||||
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
|
||||
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
|
||||
- To start recording from scratch, **manually delete** the dataset directory.
|
||||
|
||||
##### 3. Recording Parameters
|
||||
@@ -422,7 +446,7 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
robot.connect()
|
||||
@@ -490,6 +514,83 @@ Additionally you can provide extra `tags` or specify a `license` for your model
|
||||
|
||||
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
#### Train using Hugging Face Jobs
|
||||
|
||||
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
|
||||
|
||||
To run the training use this command:
|
||||
|
||||
<hfoptions id="train_with_hf_jobs">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
hf jobs run \
|
||||
--flavor a10g-small \
|
||||
--timeout 4h \
|
||||
--secrets HF_TOKEN \
|
||||
huggingface/lerobot-gpu:latest \
|
||||
-- \
|
||||
python -m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=username/dataset \
|
||||
--policy.type=act \
|
||||
--steps=5000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=username/your_policy \
|
||||
--log_freq=100
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from huggingface_hub import run_job, get_token
|
||||
|
||||
run_name = "act_so101_hf_jobs"
|
||||
dataset_id = "username/dataset"
|
||||
user_hub_id = "username"
|
||||
|
||||
command_args = [
|
||||
"python", "-m", "lerobot.scripts.lerobot_train",
|
||||
"--dataset.repo_id", dataset_id,
|
||||
"--policy.type", "act",
|
||||
"--steps", "5000",
|
||||
"--batch_size", "16",
|
||||
"--num_workers", "4",
|
||||
"--policy.device", "cuda",
|
||||
"--log_freq", "100",
|
||||
"--save_freq", "1000",
|
||||
"--save_checkpoint", "true",
|
||||
"--wandb.enable", "false",
|
||||
"--policy.repo_id", f"{user_hub_id}/{run_name}"
|
||||
]
|
||||
|
||||
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
|
||||
|
||||
job_info = run_job(
|
||||
image="huggingface/lerobot-gpu:latest",
|
||||
command=command_args,
|
||||
flavor="a10g-small",
|
||||
timeout="4h",
|
||||
secrets={"HF_TOKEN": get_token()}
|
||||
)
|
||||
|
||||
print("\n🚀 Job successfully launched!")
|
||||
print(f"🔹 Job ID: {job_info.id}")
|
||||
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
|
||||
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
|
||||
For longer training sessions increase the timeout.
|
||||
|
||||
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
|
||||
|
||||
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
|
||||
Once you are logged in, you can run inference in your setup by doing:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
||||
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
|
||||
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
|
||||
--dataset.episode_time_s=50 \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
# <- RTC optional, use when running on low power hardware \
|
||||
# --inference.type=rtc \
|
||||
# --inference.rtc.execution_horizon=10 \
|
||||
# --inference.rtc.max_guidance_weight=10.0 \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
# --teleop.id=my_red_leader_arm \
|
||||
# --display_data=true #optional use if you want to see the camera stream \
|
||||
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
|
||||
```
|
||||
|
||||
|
||||
@@ -15,10 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
|
||||
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
|
||||
|
||||
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
||||
of the source video, draws a progress line on each frame, and writes the result.
|
||||
The progress data is read from a parquet file that lives alongside the dataset
|
||||
(configurable via ``--progress-file``).
|
||||
|
||||
Usage:
|
||||
python examples/dataset/create_progress_videos.py \
|
||||
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
|
||||
TASK_FONT_SCALE = 0.55
|
||||
|
||||
|
||||
def download_episode_metadata(repo_id: str, episode: int) -> Path:
|
||||
"""Download only the metadata and sarm_progress files for a dataset.
|
||||
def download_episode_metadata(
|
||||
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> Path:
|
||||
"""Download only the metadata and per-frame progress file for a dataset.
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace dataset repository ID.
|
||||
episode: Episode index (used for logging only; all meta is fetched).
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Local cache path for the downloaded snapshot.
|
||||
"""
|
||||
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
|
||||
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
|
||||
local_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||
allow_patterns=["meta/**", progress_file],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
|
||||
return video_path
|
||||
|
||||
|
||||
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
|
||||
"""Load sarm_progress values for an episode.
|
||||
def load_progress_data(
|
||||
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> np.ndarray | None:
|
||||
"""Load per-frame progress values for an episode.
|
||||
|
||||
Args:
|
||||
local_path: Dataset cache root.
|
||||
episode: Episode index.
|
||||
progress_file: Filename of the per-frame progress parquet.
|
||||
|
||||
Returns:
|
||||
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
||||
"""
|
||||
parquet_path = local_path / "sarm_progress.parquet"
|
||||
parquet_path = local_path / progress_file
|
||||
if not parquet_path.exists():
|
||||
logging.warning("sarm_progress.parquet not found")
|
||||
logging.warning("%s not found", progress_file)
|
||||
return None
|
||||
df = pd.read_parquet(parquet_path)
|
||||
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
|
||||
logging.info(" %s columns: %s", progress_file, list(df.columns))
|
||||
episode_df = df[df["episode_index"] == episode].copy()
|
||||
if episode_df.empty:
|
||||
logging.warning("No sarm_progress rows for episode %d", episode)
|
||||
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
|
||||
return None
|
||||
episode_df = episode_df.sort_values("frame_index")
|
||||
|
||||
@@ -576,6 +585,7 @@ def process_dataset(
|
||||
camera_key: str | None,
|
||||
output_dir: Path,
|
||||
create_gif: bool = False,
|
||||
progress_file: str = "sarm_progress.parquet",
|
||||
) -> Path | None:
|
||||
"""Full pipeline: download, extract metadata, composite progress, write output.
|
||||
|
||||
@@ -585,6 +595,8 @@ def process_dataset(
|
||||
camera_key: Camera key to use, or None for auto-selection.
|
||||
output_dir: Directory to write output files.
|
||||
create_gif: If True, also generate a GIF from the MP4.
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Path to the final output file, or None on failure.
|
||||
@@ -592,7 +604,7 @@ def process_dataset(
|
||||
safe_name = repo_id.replace("/", "_")
|
||||
logging.info("Processing: %s | episode %d", repo_id, episode)
|
||||
|
||||
local_path = download_episode_metadata(repo_id, episode)
|
||||
local_path = download_episode_metadata(repo_id, episode, progress_file)
|
||||
logging.info(" Local cache: %s", local_path)
|
||||
|
||||
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
||||
@@ -600,9 +612,9 @@ def process_dataset(
|
||||
|
||||
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
||||
|
||||
progress_data = load_progress_data(local_path, episode)
|
||||
progress_data = load_progress_data(local_path, episode, progress_file)
|
||||
if progress_data is None:
|
||||
logging.error("Could not load sarm_progress data. Skipping overlay.")
|
||||
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
|
||||
return None
|
||||
|
||||
logging.info(" Progress frames: %d", len(progress_data))
|
||||
@@ -627,7 +639,7 @@ def process_dataset(
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
|
||||
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
@@ -658,6 +670,15 @@ def main() -> None:
|
||||
action="store_true",
|
||||
help="Also generate a GIF from the MP4 output.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-file",
|
||||
type=str,
|
||||
default="sarm_progress.parquet",
|
||||
help=(
|
||||
"Filename of the per-frame progress parquet inside the dataset repo "
|
||||
"(default: 'sarm_progress.parquet')."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
@@ -670,6 +691,7 @@ def main() -> None:
|
||||
camera_key=args.camera_key,
|
||||
output_dir=args.output_dir,
|
||||
create_gif=args.gif,
|
||||
progress_file=args.progress_file,
|
||||
)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -138,7 +138,9 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
||||
# Common
|
||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
|
||||
@@ -255,8 +255,7 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) ->
|
||||
remaining = config_data[field]
|
||||
if remaining:
|
||||
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
||||
else:
|
||||
del config_data[field]
|
||||
del config_data[field]
|
||||
modified = True
|
||||
|
||||
if not modified:
|
||||
@@ -311,7 +310,13 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
else:
|
||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
||||
if config_path_cli:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = draccus.parse(
|
||||
config_class=argtype,
|
||||
config_path=config_path_cli or config_path,
|
||||
args=cli_args,
|
||||
)
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
|
||||
@@ -18,12 +18,25 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.import_utils import _placo_available, require_package
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
if TYPE_CHECKING or _placo_available:
|
||||
_placo_runtime_error: ImportError | None = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import placo # type: ignore[import-not-found]
|
||||
else:
|
||||
placo = None
|
||||
try:
|
||||
import placo # type: ignore[import-not-found]
|
||||
except ImportError as _placo_import_err:
|
||||
placo = None
|
||||
_placo_runtime_error = _placo_import_err
|
||||
|
||||
|
||||
def _raise_if_placo_unusable() -> None:
|
||||
if placo is None and _placo_runtime_error is not None:
|
||||
raise ImportError(
|
||||
f"placo is installed but failed to import: {_placo_runtime_error!s}"
|
||||
) from _placo_runtime_error
|
||||
|
||||
|
||||
class RobotKinematics:
|
||||
@@ -44,6 +57,7 @@ class RobotKinematics:
|
||||
joint_names (list[str] | None): List of joint names to use for the kinematics solver
|
||||
"""
|
||||
require_package("placo", extra="placo-dep")
|
||||
_raise_if_placo_unusable()
|
||||
|
||||
self.robot = placo.RobotWrapper(urdf_path)
|
||||
self.solver = placo.KinematicsSolver(self.robot)
|
||||
|
||||
@@ -43,6 +43,7 @@ from .tables import (
|
||||
CAN_CMD_SET_ZERO,
|
||||
DEFAULT_BAUDRATE,
|
||||
DEFAULT_TIMEOUT_MS,
|
||||
HANDSHAKE_TIMEOUT_S,
|
||||
MODEL_RESOLUTION,
|
||||
MOTOR_LIMIT_PARAMS,
|
||||
NORMALIZED_DATA,
|
||||
@@ -215,14 +216,16 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
self._is_connected = False
|
||||
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
|
||||
|
||||
def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
|
||||
def _query_status_via_clear_fault(
|
||||
self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
|
||||
) -> tuple[bool, can.Message | None]:
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_id = self._get_motor_id(motor_name)
|
||||
recv_id = self._get_motor_recv_id(motor_name)
|
||||
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self._bus().send(msg)
|
||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
|
||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)
|
||||
|
||||
def _recv_status_via_clear_fault(
|
||||
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
|
||||
@@ -280,7 +283,7 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
faulted_motors = []
|
||||
|
||||
for motor_name in self.motors:
|
||||
has_fault, msg = self._query_status_via_clear_fault(motor_name)
|
||||
has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
|
||||
if msg is None:
|
||||
missing_motors.append(motor_name)
|
||||
elif has_fault:
|
||||
@@ -505,6 +508,87 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
|
||||
return responses
|
||||
|
||||
def _recv_all_messages_until_quiet(
|
||||
self,
|
||||
*,
|
||||
timeout: float = RUNNING_TIMEOUT,
|
||||
max_messages: int = 4096,
|
||||
) -> list[can.Message]:
|
||||
"""
|
||||
Receive frames until the bus goes quiet.
|
||||
|
||||
Args:
|
||||
timeout: Poll timeout used for each recv() call. Collection stops
|
||||
when one recv() times out (quiet gap).
|
||||
max_messages: Safety cap to prevent unbounded loops.
|
||||
"""
|
||||
out: list[can.Message] = []
|
||||
max_messages = max(1, max_messages)
|
||||
timeout = max(0.0, timeout)
|
||||
|
||||
try:
|
||||
while len(out) < max_messages:
|
||||
msg = self._bus().recv(timeout=timeout)
|
||||
if msg is None:
|
||||
break
|
||||
out.append(msg)
|
||||
except (can.CanError, OSError) as e:
|
||||
logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
|
||||
|
||||
return out
|
||||
|
||||
def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
|
||||
"""
|
||||
Decode all received feedback frames and update cached motor states.
|
||||
|
||||
Returns:
|
||||
Set of payload recv_ids that were successfully mapped to motors.
|
||||
"""
|
||||
processed_recv_ids: set[int] = set()
|
||||
for msg in messages:
|
||||
if len(msg.data) < 1:
|
||||
logger.debug(
|
||||
f"Dropping short CAN frame on {self.port} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
|
||||
)
|
||||
continue
|
||||
|
||||
recv_id = int(msg.data[0])
|
||||
motor_name = self._recv_id_to_motor.get(recv_id)
|
||||
if motor_name is None:
|
||||
logger.debug(
|
||||
f"Unmapped CAN frame on {self.port} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
|
||||
)
|
||||
continue
|
||||
|
||||
self._process_response(motor_name, msg)
|
||||
processed_recv_ids.add(recv_id)
|
||||
|
||||
return processed_recv_ids
|
||||
|
||||
def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
|
||||
"""
|
||||
Drain pending RX frames from the CAN interface.
|
||||
|
||||
This is used by higher-level controllers to drop stale feedback before issuing
|
||||
a fresh read cycle, so subsequent state reads are based on most recent replies.
|
||||
It should also be called once when a controller instance is created/connected,
|
||||
to clear residual frames left on the interface from previous sessions.
|
||||
"""
|
||||
drained = 0
|
||||
poll_timeout_s = max(0.0, poll_timeout_s)
|
||||
max_messages = max(1, max_messages)
|
||||
try:
|
||||
while drained < max_messages:
|
||||
msg = self._bus().recv(timeout=poll_timeout_s)
|
||||
if msg is None:
|
||||
break
|
||||
drained += 1
|
||||
except (can.CanError, OSError) as e:
|
||||
logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
|
||||
return drained
|
||||
|
||||
def _speed_control(
|
||||
self,
|
||||
motor: NameOrID,
|
||||
@@ -644,11 +728,14 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self._bus().send(msg)
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
# Read every feedback frame until RX goes quiet, then decode all of them.
|
||||
# This avoids dropping useful frames when responses from different motors interleave.
|
||||
messages = self._recv_all_messages_until_quiet()
|
||||
processed_recv_ids = self._process_feedback_messages(messages)
|
||||
|
||||
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
|
||||
for recv_id, motor_name in recv_id_to_motor.items():
|
||||
if msg := responses.get(recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
if recv_id not in processed_recv_ids:
|
||||
logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||
"""Convert float to unsigned integer for CAN transmission."""
|
||||
@@ -711,7 +798,10 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
try:
|
||||
self._decode_motor_state(msg.data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode response from {motor}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to decode response from {motor} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
|
||||
)
|
||||
|
||||
def _get_cached_value(self, motor: str, data_name: str) -> Value:
|
||||
"""Retrieve a specific value from the state cache."""
|
||||
@@ -848,20 +938,12 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
self._bus().send(msg)
|
||||
updated_motors.append(motor)
|
||||
|
||||
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
|
||||
responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
|
||||
|
||||
for response in responses.values():
|
||||
payload_motor_name = self._recv_id_to_motor.get(response.data[0])
|
||||
if payload_motor_name is not None:
|
||||
self._process_response(payload_motor_name, response)
|
||||
else:
|
||||
# Fallback: still attempt to decode based on payload byte0 mapping.
|
||||
self._decode_motor_state(response.data)
|
||||
messages = self._recv_all_messages_until_quiet()
|
||||
processed_recv_ids = self._process_feedback_messages(messages)
|
||||
|
||||
for motor in updated_motors:
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
if recv_id not in responses:
|
||||
if recv_id not in processed_recv_ids:
|
||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
|
||||
@@ -114,7 +114,8 @@ CAN_CMD_SAVE_PARAM = 0xAA
|
||||
CAN_PARAM_ID = 0x7FF
|
||||
|
||||
|
||||
RUNNING_TIMEOUT = 0.001
|
||||
RUNNING_TIMEOUT = 0.003
|
||||
HANDSHAKE_TIMEOUT_S = 0.05
|
||||
PARAM_TIMEOUT = 0.01
|
||||
|
||||
STATE_CACHE_TTL_S = 0.02
|
||||
|
||||
@@ -60,6 +60,7 @@ class Eagle25VLPreTrainedModel(PreTrainedModel):
|
||||
"SiglipEncoderLayer",
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
@@ -124,7 +124,6 @@ class Eagle25VLProcessor(ProcessorMixin):
|
||||
"videos_kwargs",
|
||||
"text_kwargs",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -26,9 +26,14 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from huggingface_hub.dataclasses import strict
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
|
||||
def strict(cls):
|
||||
return cls
|
||||
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
@@ -173,19 +178,20 @@ N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@strict
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
backbone_cfg: dict
|
||||
action_head_cfg: dict
|
||||
action_horizon: int
|
||||
action_dim: int
|
||||
backbone_cfg: dict[str, Any] | None = None
|
||||
action_head_cfg: dict[str, Any] | None = None
|
||||
action_horizon: int = 0
|
||||
action_dim: int = 0
|
||||
compute_dtype: str = "float32"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
def __post_init__(self, **kwargs):
|
||||
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
|
||||
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
|
||||
# real model
|
||||
|
||||
@@ -206,7 +206,11 @@ def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS
|
||||
"Vendor files are copied during model creation. Create the policy/model first, "
|
||||
"or call ensure_eagle_cache_ready() before building processors."
|
||||
)
|
||||
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
|
||||
proc = AutoProcessor.from_pretrained(
|
||||
str(cache_dir),
|
||||
trust_remote_code=True,
|
||||
fix_mistral_regex=False,
|
||||
)
|
||||
proc.tokenizer.padding_side = "left"
|
||||
return proc
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -141,6 +142,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -227,16 +237,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -258,15 +265,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -274,13 +282,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -488,8 +496,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -499,36 +508,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -907,7 +919,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -224,16 +234,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -255,15 +262,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -271,13 +279,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -485,8 +493,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -496,36 +505,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -880,7 +892,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
1
tests/policies/pi0_pi05/openpi_pytorch/__init__.py
Normal file
1
tests/policies/pi0_pi05/openpi_pytorch/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Lightweight vendored OpenPI PyTorch modules for PI0/PI05 parity tests."""
|
||||
22
tests/policies/pi0_pi05/openpi_pytorch/gemma.py
Normal file
22
tests/policies/pi0_pi05/openpi_pytorch/gemma.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
width: int
|
||||
depth: int
|
||||
mlp_dim: int
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
head_dim: int
|
||||
|
||||
|
||||
def get_config(variant: str) -> Config:
|
||||
"""Return the Gemma shape config needed by the OpenPI PyTorch model."""
|
||||
if variant == "dummy":
|
||||
return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16)
|
||||
if variant == "gemma_300m":
|
||||
return Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256)
|
||||
if variant == "gemma_2b":
|
||||
return Config(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256)
|
||||
raise ValueError(f"Unknown variant: {variant}")
|
||||
300
tests/policies/pi0_pi05/openpi_pytorch/gemma_pytorch.py
Normal file
300
tests/policies/pi0_pi05/openpi_pytorch/gemma_pytorch.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaForCausalLM,
|
||||
_gated_residual,
|
||||
layernorm_forward,
|
||||
)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vlm_config,
|
||||
action_expert_config,
|
||||
use_adarms=None,
|
||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||
):
|
||||
if use_adarms is None:
|
||||
use_adarms = [False, False]
|
||||
super().__init__()
|
||||
|
||||
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
||||
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
||||
vlm_config_hf.image_token_index = 257152
|
||||
vlm_config_hf.text_config.hidden_size = vlm_config.width
|
||||
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
|
||||
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
|
||||
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
|
||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||
head_dim=action_expert_config.head_dim,
|
||||
hidden_size=action_expert_config.width,
|
||||
intermediate_size=action_expert_config.mlp_dim,
|
||||
num_attention_heads=action_expert_config.num_heads,
|
||||
num_hidden_layers=action_expert_config.depth,
|
||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||
vocab_size=257152,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
dtype="float32",
|
||||
use_adarms=use_adarms[1],
|
||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||
)
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
|
||||
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
||||
if precision == "bfloat16":
|
||||
self.to(dtype=torch.bfloat16)
|
||||
elif precision == "float32":
|
||||
self.to(dtype=torch.float32)
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
]
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if any(selector in name for selector in params_to_keep_float32):
|
||||
param.data = param.data.to(dtype=torch.float32)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
# Transformers 5.4 no longer divides PaliGemma image features by sqrt(hidden_size),
|
||||
# so the upstream helper now matches OpenPI's patched PaliGemma image-scale semantics.
|
||||
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-c916907e7e52ac85ee1a1527560eae4656cd6c76141ceb1fe3da61bd5f697d2a
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
inputs_embeds: list[torch.FloatTensor] | None = None,
|
||||
use_cache: bool | None = None,
|
||||
adarms_cond: list[torch.Tensor] | None = None,
|
||||
):
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
|
||||
)
|
||||
prefix_past_key_values = prefix_output.past_key_values
|
||||
prefix_output = prefix_output.last_hidden_state
|
||||
suffix_output = None
|
||||
elif inputs_embeds[0] is None:
|
||||
suffix_output = self.gemma_expert.model.forward(
|
||||
inputs_embeds=inputs_embeds[1],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
|
||||
)
|
||||
suffix_output = suffix_output.last_hidden_state
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
hasattr(self.gemma_expert.model, "gradient_checkpointing")
|
||||
and self.gemma_expert.model.gradient_checkpointing
|
||||
and self.training
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Force enable gradient checkpointing if we're in training mode and the model supports it
|
||||
if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
|
||||
if not self.gemma_expert.model.gradient_checkpointing:
|
||||
print("Forcing gradient checkpointing to be enabled for Gemma expert model")
|
||||
self.gemma_expert.model.gradient_checkpointing = True
|
||||
use_gradient_checkpointing = True
|
||||
|
||||
# Debug gradient checkpointing status
|
||||
if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
|
||||
print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
|
||||
print(f"Model training mode: {self.training}")
|
||||
print(
|
||||
f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
|
||||
)
|
||||
if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
|
||||
print(
|
||||
f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
|
||||
)
|
||||
self._debug_gc_printed = True
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layernorm_forward(
|
||||
layer.input_layernorm, hidden_states, adarms_cond[i]
|
||||
)
|
||||
gates.append(gate)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# Concatenate and process attention
|
||||
query_states = torch.cat(query_states, dim=2)
|
||||
key_states = torch.cat(key_states, dim=2)
|
||||
value_states = torch.cat(value_states, dim=2)
|
||||
|
||||
dummy_tensor = torch.zeros(
|
||||
query_states.shape[0],
|
||||
query_states.shape[2],
|
||||
query_states.shape[-1],
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = self.paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
self.paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = self.paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||
|
||||
# first residual
|
||||
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||
after_first_residual = out_emb.clone()
|
||||
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
out_emb = layer.mlp(out_emb)
|
||||
# second residual
|
||||
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||
outputs_embeds.append(out_emb)
|
||||
start_pos = end_pos
|
||||
|
||||
return outputs_embeds
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
# Old code removed - now using compute_layer_complete function above
|
||||
|
||||
# final norm
|
||||
# Define final norm computation function for gradient checkpointing
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
# Apply gradient checkpointing to final norm if enabled
|
||||
if use_gradient_checkpointing:
|
||||
outputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_final_norms,
|
||||
inputs_embeds,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
)
|
||||
else:
|
||||
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
|
||||
|
||||
prefix_output = outputs_embeds[0]
|
||||
suffix_output = outputs_embeds[1]
|
||||
prefix_past_key_values = None
|
||||
|
||||
return [prefix_output, suffix_output], prefix_past_key_values
|
||||
79
tests/policies/pi0_pi05/openpi_pytorch/image_tools.py
Normal file
79
tests/policies/pi0_pi05/openpi_pytorch/image_tools.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
|
||||
|
||||
def resize_with_pad_torch(
|
||||
images: torch.Tensor,
|
||||
height: int,
|
||||
width: int,
|
||||
mode: str = "bilinear",
|
||||
) -> torch.Tensor:
|
||||
"""PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
|
||||
by padding with black. If the image is float32, it must be in the range [-1, 1].
|
||||
|
||||
Args:
|
||||
images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
|
||||
height: Target height
|
||||
width: Target width
|
||||
mode: Interpolation mode ('bilinear', 'nearest', etc.)
|
||||
|
||||
Returns:
|
||||
Resized and padded tensor with same shape format as input
|
||||
"""
|
||||
# Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
|
||||
if images.shape[-1] <= 4: # Assume channels-last format
|
||||
channels_last = True
|
||||
# Convert to channels-first for torch operations
|
||||
if images.dim() == 3:
|
||||
images = images.unsqueeze(0) # Add batch dimension
|
||||
images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w]
|
||||
else:
|
||||
channels_last = False
|
||||
if images.dim() == 3:
|
||||
images = images.unsqueeze(0) # Add batch dimension
|
||||
|
||||
batch_size, channels, cur_height, cur_width = images.shape
|
||||
|
||||
# Calculate resize ratio
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
|
||||
# Resize
|
||||
resized_images = F.interpolate(
|
||||
images,
|
||||
size=(resized_height, resized_width),
|
||||
mode=mode,
|
||||
align_corners=False if mode == "bilinear" else None,
|
||||
)
|
||||
|
||||
# Handle dtype-specific clipping
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
|
||||
# Calculate padding
|
||||
pad_h0, remainder_h = divmod(height - resized_height, 2)
|
||||
pad_h1 = pad_h0 + remainder_h
|
||||
pad_w0, remainder_w = divmod(width - resized_width, 2)
|
||||
pad_w1 = pad_w0 + remainder_w
|
||||
|
||||
# Pad
|
||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
||||
padded_images = F.pad(
|
||||
resized_images,
|
||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||
mode="constant",
|
||||
value=constant_value,
|
||||
)
|
||||
|
||||
# Convert back to original format if needed
|
||||
if channels_last:
|
||||
padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
if batch_size == 1 and images.shape[0] == 1:
|
||||
padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
|
||||
|
||||
return padded_images
|
||||
471
tests/policies/pi0_pi05/openpi_pytorch/pi0_pytorch.py
Normal file
471
tests/policies/pi0_pi05/openpi_pytorch/pi0_pytorch.py
Normal file
@@ -0,0 +1,471 @@
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
import tests.policies.pi0_pi05.openpi_pytorch.gemma as _gemma
|
||||
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as _preprocessing
|
||||
from tests.policies.pi0_pi05.openpi_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "cpu":
|
||||
# CPU doesn't support bfloat16, use float32 instead
|
||||
if target_dtype == torch.bfloat16:
|
||||
return torch.float32
|
||||
if target_dtype == torch.float64:
|
||||
return torch.float64
|
||||
return target_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
|
||||
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
||||
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,))
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
"""Copied from big_vision.
|
||||
|
||||
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
||||
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
||||
setup several types of attention, for example:
|
||||
|
||||
[[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
themselves and the last 3 tokens have a causal attention. The first
|
||||
entry could also be a 1 without changing behaviour.
|
||||
|
||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
Args:
|
||||
input_mask: bool[B, N] true if its part of the input, false if padding.
|
||||
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
||||
it and 0 where it shares the same attention mask as the previous token.
|
||||
"""
|
||||
if att_masks.ndim != 2:
|
||||
raise ValueError(att_masks.ndim)
|
||||
if pad_masks.ndim != 2:
|
||||
raise ValueError(pad_masks.ndim)
|
||||
|
||||
cumsum = torch.cumsum(att_masks, dim=1)
|
||||
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
class PI0Pytorch(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.pi05 = config.pi05
|
||||
|
||||
paligemma_config = _gemma.get_config(config.paligemma_variant)
|
||||
action_expert_config = _gemma.get_config(config.action_expert_variant)
|
||||
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||
paligemma_config,
|
||||
action_expert_config,
|
||||
use_adarms=[False, True] if self.pi05 else [False, False],
|
||||
precision=config.dtype,
|
||||
)
|
||||
|
||||
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)
|
||||
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)
|
||||
|
||||
if self.pi05:
|
||||
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
else:
|
||||
self.state_proj = nn.Linear(config.action_dim, action_expert_config.width)
|
||||
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
|
||||
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
|
||||
torch.set_float32_matmul_precision("high")
|
||||
if config.pytorch_compile_mode is not None:
|
||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode)
|
||||
|
||||
# Initialize gradient checkpointing flag
|
||||
self.gradient_checkpointing_enabled = False
|
||||
|
||||
# The upstream OpenPI module verifies a site-package Transformers patch here.
|
||||
# This vendored test copy instead routes through LeRobot's local PiGemma compatibility layer.
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||
|
||||
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
|
||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def is_gradient_checkpointing_enabled(self):
|
||||
"""Check if gradient checkpointing is enabled."""
|
||||
return self.gradient_checkpointing_enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def _prepare_attention_masks_4d(self, att_2d_masks):
|
||||
"""Helper method to prepare 4D attention masks for transformer."""
|
||||
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
||||
return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
|
||||
|
||||
def _preprocess_observation(self, observation, *, train=True):
|
||||
"""Helper method to preprocess observation."""
|
||||
observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
|
||||
return (
|
||||
list(observation.images.values()),
|
||||
list(observation.image_masks.values()),
|
||||
observation.tokenized_prompt,
|
||||
observation.tokenized_prompt_mask,
|
||||
observation.state,
|
||||
)
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
return torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
||||
for PaliGemma transformer processing.
|
||||
"""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# Process images
|
||||
for img, img_mask in zip(images, img_masks, strict=True):
|
||||
|
||||
def image_embed_func(img):
|
||||
return self.paligemma_with_expert.embed_image(img)
|
||||
|
||||
img_emb = self._apply_checkpoint(image_embed_func, img)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
||||
|
||||
# Create attention masks so that image tokens attend to each other
|
||||
att_masks += [0] * num_img_embs
|
||||
|
||||
# Process language tokens
|
||||
def lang_embed_func(lang_tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||
# Transformers > 5.4 scales Gemma token embeddings inside embed_tokens, matching
|
||||
# OpenPI's former explicit sqrt(hidden_size) multiply without applying it twice.
|
||||
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834
|
||||
return lang_emb
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
||||
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(lang_masks)
|
||||
|
||||
# full attention between image and language inputs
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
|
||||
# Get batch size from the first dimension of the concatenated tensors
|
||||
bsize = pad_masks.shape[0]
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def embed_suffix(self, state, noisy_actions, timestep):
|
||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
if not self.pi05:
|
||||
if self.state_proj.weight.dtype == torch.float32:
|
||||
state = state.to(torch.float32)
|
||||
|
||||
# Embed state
|
||||
def state_proj_func(state):
|
||||
return self.state_proj(state)
|
||||
|
||||
state_emb = self._apply_checkpoint(state_proj_func, state)
|
||||
|
||||
embs.append(state_emb[:, None, :])
|
||||
bsize = state_emb.shape[0]
|
||||
device = state_emb.device
|
||||
|
||||
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
||||
pad_masks.append(state_mask)
|
||||
|
||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||
att_masks += [1]
|
||||
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.action_in_proj.out_features,
|
||||
min_period=4e-3,
|
||||
max_period=4.0,
|
||||
device=timestep.device,
|
||||
)
|
||||
time_emb = time_emb.type(dtype=timestep.dtype)
|
||||
|
||||
# Fuse timestep + action information using an MLP
|
||||
def action_proj_func(noisy_actions):
|
||||
return self.action_in_proj(noisy_actions)
|
||||
|
||||
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
|
||||
if not self.pi05:
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
# Apply MLP layers
|
||||
def mlp_func(action_time_emb):
|
||||
x = self.action_time_mlp_in(action_time_emb)
|
||||
x = F.silu(x) # swish == silu
|
||||
return self.action_time_mlp_out(x)
|
||||
|
||||
action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
|
||||
adarms_cond = None
|
||||
else:
|
||||
# time MLP (for adaRMS)
|
||||
def time_mlp_func(time_emb):
|
||||
x = self.time_mlp_in(time_emb)
|
||||
x = F.silu(x) # swish == silu
|
||||
x = self.time_mlp_out(x)
|
||||
return F.silu(x)
|
||||
|
||||
time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
|
||||
action_time_emb = action_emb
|
||||
adarms_cond = time_emb
|
||||
|
||||
# Add to input tokens
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
att_masks += [1] + ([0] * (self.config.action_horizon - 1))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(self, observation, actions, noise=None, time=None) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
|
||||
observation, train=True
|
||||
)
|
||||
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
|
||||
# Prepare attention masks
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
|
||||
|
||||
# Apply gradient checkpointing if enabled
|
||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
return suffix_out
|
||||
|
||||
suffix_out = self._apply_checkpoint(
|
||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
|
||||
# Apply gradient checkpointing to final action projection if enabled
|
||||
def action_out_proj_func(suffix_out):
|
||||
return self.action_out_proj(suffix_out)
|
||||
|
||||
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
|
||||
|
||||
return F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = observation.state.shape[0]
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
|
||||
observation, train=False
|
||||
)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
# Compute image and language key value cache
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks_4d,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
|
||||
# Euler step - use new tensor assignment instead of in-place operation
|
||||
x_t = x_t + dt * v_t
|
||||
time += dt
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
||||
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
# Prepare attention masks
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[None, suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
return self.action_out_proj(suffix_out)
|
||||
179
tests/policies/pi0_pi05/openpi_pytorch/preprocessing_pytorch.py
Normal file
179
tests/policies/pi0_pi05/openpi_pytorch/preprocessing_pytorch.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from tests.policies.pi0_pi05.openpi_pytorch import image_tools
|
||||
|
||||
logger = logging.getLogger("openpi")
|
||||
|
||||
# Constants moved from model.py
|
||||
IMAGE_KEYS = (
|
||||
"base_0_rgb",
|
||||
"left_wrist_0_rgb",
|
||||
"right_wrist_0_rgb",
|
||||
)
|
||||
|
||||
IMAGE_RESOLUTION = (224, 224)
|
||||
|
||||
|
||||
def preprocess_observation_pytorch(
|
||||
observation,
|
||||
*,
|
||||
train: bool = False,
|
||||
image_keys: Sequence[str] = IMAGE_KEYS,
|
||||
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
|
||||
):
|
||||
"""Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
|
||||
|
||||
This function avoids complex type annotations that can cause torch.compile issues.
|
||||
"""
|
||||
if not set(image_keys).issubset(observation.images):
|
||||
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
|
||||
|
||||
batch_shape = observation.state.shape[:-1]
|
||||
|
||||
out_images = {}
|
||||
for key in image_keys:
|
||||
image = observation.images[key]
|
||||
|
||||
# TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
|
||||
# Handle both [B, C, H, W] and [B, H, W, C] formats
|
||||
is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1
|
||||
|
||||
if is_channels_first:
|
||||
# Convert [B, C, H, W] to [B, H, W, C] for processing
|
||||
image = image.permute(0, 2, 3, 1)
|
||||
|
||||
if image.shape[1:3] != image_resolution:
|
||||
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
|
||||
image = image_tools.resize_with_pad_torch(image, *image_resolution)
|
||||
|
||||
if train:
|
||||
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
|
||||
image = image / 2.0 + 0.5
|
||||
|
||||
# Apply PyTorch-based augmentations
|
||||
if "wrist" not in key:
|
||||
# Geometric augmentations for non-wrist cameras
|
||||
height, width = image.shape[1:3]
|
||||
|
||||
# Random crop and resize
|
||||
crop_height = int(height * 0.95)
|
||||
crop_width = int(width * 0.95)
|
||||
|
||||
# Random crop
|
||||
max_h = height - crop_height
|
||||
max_w = width - crop_width
|
||||
if max_h > 0 and max_w > 0:
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
|
||||
start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
|
||||
image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
|
||||
|
||||
# Resize back to original size
|
||||
image = torch.nn.functional.interpolate(
|
||||
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
size=(height, width),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Random rotation (small angles)
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees
|
||||
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
|
||||
# Convert to radians
|
||||
angle_rad = angle * torch.pi / 180.0
|
||||
|
||||
# Create rotation matrix
|
||||
cos_a = torch.cos(angle_rad)
|
||||
sin_a = torch.sin(angle_rad)
|
||||
|
||||
# Apply rotation using grid_sample
|
||||
grid_x = torch.linspace(-1, 1, width, device=image.device)
|
||||
grid_y = torch.linspace(-1, 1, height, device=image.device)
|
||||
|
||||
# Create meshgrid
|
||||
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
|
||||
|
||||
# Expand to batch dimension
|
||||
grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
|
||||
grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)
|
||||
|
||||
# Apply rotation transformation
|
||||
grid_x_rot = grid_x * cos_a - grid_y * sin_a
|
||||
grid_y_rot = grid_x * sin_a + grid_y * cos_a
|
||||
|
||||
# Stack and reshape for grid_sample
|
||||
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
|
||||
|
||||
image = torch.nn.functional.grid_sample(
|
||||
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
grid,
|
||||
mode="bilinear",
|
||||
padding_mode="zeros",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Color augmentations for all cameras
|
||||
# Random brightness
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
brightness_factor = (
|
||||
0.7 + torch.rand(1, device=image.device) * 0.6
|
||||
) # Random factor between 0.7 and 1.3
|
||||
image = image * brightness_factor
|
||||
|
||||
# Random contrast
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
contrast_factor = (
|
||||
0.6 + torch.rand(1, device=image.device) * 0.8
|
||||
) # Random factor between 0.6 and 1.4
|
||||
mean = image.mean(dim=[1, 2, 3], keepdim=True)
|
||||
image = (image - mean) * contrast_factor + mean
|
||||
|
||||
# Random saturation (convert to HSV, modify S, convert back)
|
||||
# For simplicity, we'll just apply a random scaling to the color channels
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
saturation_factor = (
|
||||
0.5 + torch.rand(1, device=image.device) * 1.0
|
||||
) # Random factor between 0.5 and 1.5
|
||||
gray = image.mean(dim=-1, keepdim=True)
|
||||
image = gray + (image - gray) * saturation_factor
|
||||
|
||||
# Clamp values to [0, 1]
|
||||
image = torch.clamp(image, 0, 1)
|
||||
|
||||
# Back to [-1, 1]
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
# Convert back to [B, C, H, W] format if it was originally channels-first
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
|
||||
|
||||
out_images[key] = image
|
||||
|
||||
# obtain mask
|
||||
out_masks = {}
|
||||
for key in out_images:
|
||||
if key not in observation.image_masks:
|
||||
# do not mask by default
|
||||
out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)
|
||||
else:
|
||||
out_masks[key] = observation.image_masks[key]
|
||||
|
||||
# Create a simple object with the required attributes instead of using the complex Observation class
|
||||
class SimpleProcessedObservation:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
return SimpleProcessedObservation(
|
||||
images=out_images,
|
||||
image_masks=out_masks,
|
||||
state=observation.state,
|
||||
tokenized_prompt=observation.tokenized_prompt,
|
||||
tokenized_prompt_mask=observation.tokenized_prompt_mask,
|
||||
token_ar_mask=observation.token_ar_mask,
|
||||
token_loss_mask=observation.token_loss_mask,
|
||||
)
|
||||
101
tests/policies/pi0_pi05/test_pi05_compile.py
Normal file
101
tests/policies/pi0_pi05/test_pi05_compile.py
Normal file
@@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi05 import PI05Config # noqa: E402
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
|
||||
assert_cache_stability,
|
||||
assert_compiled_output_matches_eager,
|
||||
assert_explain_has_no_graph_breaks,
|
||||
benchmark_runtime,
|
||||
make_compile_config,
|
||||
reset_compile_state,
|
||||
)
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
|
||||
def _make_model(*, compile_model):
|
||||
return PI05Pytorch(make_compile_config(PI05Config, compile_model=compile_model)).cuda().eval()
|
||||
|
||||
|
||||
def _make_dummy_inputs(config):
|
||||
device = torch.device("cuda")
|
||||
common = {
|
||||
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
|
||||
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
|
||||
"tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
|
||||
"masks": torch.ones(1, 5, dtype=torch.bool, device=device),
|
||||
}
|
||||
forward_kwargs = {
|
||||
**common,
|
||||
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"time": torch.rand(1, device=device),
|
||||
}
|
||||
sample_kwargs = {
|
||||
**common,
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"num_steps": config.num_inference_steps,
|
||||
}
|
||||
return forward_kwargs, sample_kwargs
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_torch_compile_forward_and_sample_actions():
|
||||
if not hasattr(torch, "compile"):
|
||||
pytest.skip("torch.compile is not available")
|
||||
if not torch._dynamo.is_dynamo_supported():
|
||||
pytest.skip("torch._dynamo is not supported on this platform")
|
||||
|
||||
torch.manual_seed(0)
|
||||
eager_model = _make_model(compile_model=False)
|
||||
torch.manual_seed(0)
|
||||
compiled_model = _make_model(compile_model=True)
|
||||
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
|
||||
|
||||
try:
|
||||
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
|
||||
|
||||
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi05.forward")
|
||||
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi05.sample_actions")
|
||||
|
||||
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi05.forward")
|
||||
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi05.sample_actions")
|
||||
|
||||
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi05.forward")
|
||||
benchmark_runtime(
|
||||
eager_model.sample_actions,
|
||||
compiled_model.sample_actions,
|
||||
sample_kwargs,
|
||||
"pi05.sample_actions",
|
||||
)
|
||||
finally:
|
||||
reset_compile_state()
|
||||
del eager_model
|
||||
del compiled_model
|
||||
torch.cuda.empty_cache()
|
||||
@@ -14,52 +14,56 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation"""
|
||||
"""Compare LeRobot PI0.5 against the vendored OpenPI PyTorch reference."""
|
||||
|
||||
import gc
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip if openpi or transformers is not available
|
||||
pytest.importorskip("openpi")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
from lerobot.configs import PreTrainedConfig # noqa: E402
|
||||
from lerobot.policies.pi05 import PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
deterministic_openpi_forward_preprocess,
|
||||
fix_reference_state_dict,
|
||||
fixed_flow_sampling,
|
||||
load_openpi_reference_state_dict,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
|
||||
from lerobot.types import PolicyAction # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 200
|
||||
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
COMPILE_MODE = "default"
|
||||
FORWARD_RTOL = 1e-4
|
||||
FORWARD_ATOL = 1e-4
|
||||
SAMPLE_RTOL = 1e-2
|
||||
SAMPLE_ATOL = 5e-3
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
"action": {
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
@@ -88,6 +92,15 @@ DUMMY_DATASET_STATS = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_cuda_after_test():
|
||||
yield
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
class PI05BaseOriginalConfig:
|
||||
action_dim: int = DUMMY_ACTION_DIM
|
||||
action_horizon: int = DUMMY_ACTION_HORIZON
|
||||
@@ -96,341 +109,163 @@ class PI05BaseOriginalConfig:
|
||||
precision: str = "float32"
|
||||
pi05: bool = True
|
||||
dtype: str = "float32"
|
||||
pytorch_compile_mode: str | None = None
|
||||
|
||||
|
||||
def instantiate_lerobot_pi05(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI05Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI05Policy.from_pretrained(pretrained_name_or_path="lerobot/pi05_base", strict=True)
|
||||
else:
|
||||
config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI05Policy(config)
|
||||
def instantiate_lerobot_pi05(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
config = PreTrainedConfig.from_pretrained("lerobot/pi05_base")
|
||||
config.device = str(DEVICE)
|
||||
config.dtype = "float32"
|
||||
config.compile_model = compile_model
|
||||
config.compile_mode = COMPILE_MODE
|
||||
config.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
policy = PI05Policy.from_pretrained("lerobot/pi05_base", config=config, strict=True)
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi05_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
policy.config.device = str(DEVICE)
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
return policy, preprocessor
|
||||
|
||||
|
||||
def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None):
|
||||
config = PI05BaseOriginalConfig()
|
||||
policy = PI0Pytorch(config)
|
||||
def instantiate_original_pi05():
|
||||
policy = PI0Pytorch(PI05BaseOriginalConfig()).to(DEVICE)
|
||||
|
||||
if from_pretrained:
|
||||
try:
|
||||
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi05_base)...")
|
||||
|
||||
# Download the model from HuggingFace Hub
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the entire repository
|
||||
if model_path and os.path.exists(model_path):
|
||||
cache_dir = model_path
|
||||
print(f"Using cached model from: {cache_dir}")
|
||||
else:
|
||||
cache_dir = snapshot_download(repo_id="lerobot/pi05_base", repo_type="model")
|
||||
print(f"Downloaded model to: {cache_dir}")
|
||||
|
||||
# Try to load safetensors format first
|
||||
model_file = os.path.join(cache_dir, "model.safetensors")
|
||||
if os.path.exists(model_file):
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
print(f"Loaded {len(state_dict)} parameters from safetensors")
|
||||
else:
|
||||
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
|
||||
|
||||
# Load the state dict into the model
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys: {len(missing_keys)}")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in missing_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"Unexpected keys: {len(unexpected_keys)}")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in unexpected_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("All pretrained weights loaded successfully!")
|
||||
else:
|
||||
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load pretrained weights: {e}")
|
||||
print(" Using randomly initialized weights...")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
policy.to(DEVICE)
|
||||
# NOTE: `lerobot/pi05_base` 的 LeRobot loader 和 PI0 一样会在 strict load 前做 key
|
||||
# 兼容转换,因此预期没有 missing_keys 或 unexpected_keys。vendored reference 则是裸
|
||||
# `nn.Module`,需要在测试侧补齐 checkpoint 与模块命名之间的最小差异。
|
||||
# NOTE: `lm_head.weight` 是 PaliGemma tied embedding 的保存名;LeRobot 的
|
||||
# from_pretrained 会把它映射到内部 `embed_tokens.weight`,而 reference 模型没有这层
|
||||
# loader,所以这里手动复用同一份 tensor,避免把权重别名差异误判成模型差异。
|
||||
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi05_base"))
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
assert missing_keys == []
|
||||
assert unexpected_keys == []
|
||||
return policy
|
||||
|
||||
|
||||
def create_dummy_data():
|
||||
batch_size = 2 # Reduce batch size for testing
|
||||
device = DEVICE
|
||||
|
||||
# Use the exact same prompt for both implementations
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
"observation.images.left_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
"observation.images.right_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
|
||||
"""Extract the exact same processed inputs that LeRobot uses internally."""
|
||||
# Get the tokenized language from LeRobot's internal method
|
||||
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
|
||||
|
||||
# Get the preprocessed images from LeRobot's internal method
|
||||
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for original implementation
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
|
||||
def prepare_parity_inputs(lerobot_pi05, lerobot_preprocessor):
|
||||
torch.manual_seed(0)
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
)
|
||||
openpi_actions = openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
)
|
||||
assert_processor_inputs_match_lerobot(
|
||||
lerobot_pi05,
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=False,
|
||||
)
|
||||
batch_size = raw_batch[OBS_STATE].shape[0]
|
||||
noise = torch.randn(
|
||||
batch_size,
|
||||
DUMMY_ACTION_HORIZON,
|
||||
DUMMY_ACTION_DIM,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE,
|
||||
)
|
||||
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
|
||||
return lerobot_batch, openpi_observation, openpi_actions, noise, time
|
||||
|
||||
|
||||
class PI05Observation:
|
||||
"""Observation class that matches the original OpenPI format."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state,
|
||||
images,
|
||||
image_masks,
|
||||
tokenized_prompt,
|
||||
tokenized_prompt_mask,
|
||||
token_ar_mask,
|
||||
token_loss_mask,
|
||||
):
|
||||
self.state = state
|
||||
self.images = images
|
||||
self.image_masks = image_masks
|
||||
self.tokenized_prompt = tokenized_prompt
|
||||
self.tokenized_prompt_mask = tokenized_prompt_mask
|
||||
self.token_ar_mask = token_ar_mask
|
||||
self.token_loss_mask = token_loss_mask
|
||||
|
||||
|
||||
def create_original_observation_with_openpi_preprocessing(batch):
|
||||
"""Create observation object for OpenPI using OpenPI's own preprocessing with pi05 state tokenizer."""
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
device = batch["observation.state"].device
|
||||
|
||||
# Create tokenizer for OpenPI (same as LeRobot uses)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
# Get task description (pi05 processor handles all text formatting)
|
||||
tasks = batch.get("task", ["Pick up the object"] * batch_size)
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks] * batch_size
|
||||
elif len(tasks) == 1:
|
||||
tasks = tasks * batch_size
|
||||
|
||||
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
|
||||
state = batch["observation.state"]
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
|
||||
state = pad_vector(state, DUMMY_STATE_DIM)
|
||||
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
# Create pi05-formatted prompts that include state information
|
||||
full_prompts = []
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
tokenized = tokenizer(
|
||||
full_prompts,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=DUMMY_MAX_TOKEN_LEN,
|
||||
return_tensors="pt",
|
||||
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(
|
||||
compile_model=compile_model,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
original_pi05 = instantiate_original_pi05()
|
||||
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
|
||||
lerobot_pi05,
|
||||
lerobot_preprocessor,
|
||||
)
|
||||
|
||||
lang_tokens = tokenized["input_ids"].to(device)
|
||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
if gradient_checkpointing:
|
||||
lerobot_pi05.train()
|
||||
else:
|
||||
lerobot_pi05.eval()
|
||||
original_pi05.eval()
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for OpenPI
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
with fixed_flow_sampling(lerobot_pi05.model, noise=noise, time=time):
|
||||
lerobot_loss, _ = lerobot_pi05(lerobot_batch, reduction="none")
|
||||
with deterministic_openpi_forward_preprocess(original_pi05):
|
||||
openpi_losses = original_pi05(openpi_observation, openpi_actions, noise=noise, time=time)
|
||||
openpi_loss = openpi_losses.mean(dim=(1, 2))
|
||||
|
||||
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
|
||||
image_dict = {
|
||||
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
|
||||
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
}
|
||||
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
|
||||
# Create image masks (all ones for real images)
|
||||
image_masks_dict = {}
|
||||
for key in image_dict:
|
||||
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||
|
||||
# Create raw observation object (before preprocessing)
|
||||
raw_observation = PI05Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
|
||||
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(compile_model=compile_model)
|
||||
original_pi05 = instantiate_original_pi05()
|
||||
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
|
||||
lerobot_pi05,
|
||||
lerobot_preprocessor,
|
||||
)
|
||||
|
||||
# Now use OpenPI's preprocessing
|
||||
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
|
||||
|
||||
return processed_obs
|
||||
|
||||
|
||||
def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
|
||||
_batch_size = batch["observation.state"].shape[0]
|
||||
_device = batch["observation.state"].device
|
||||
|
||||
# Extract the exact same processed inputs that LeRobot uses
|
||||
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
|
||||
extract_lerobot_processed_inputs(lerobot_pi0, batch)
|
||||
)
|
||||
|
||||
# Convert images list to dict with original OpenPI keys
|
||||
image_dict = {
|
||||
"base_0_rgb": images[0],
|
||||
"left_wrist_0_rgb": images[1],
|
||||
"right_wrist_0_rgb": images[2],
|
||||
}
|
||||
|
||||
# Convert image masks list to dict with original OpenPI keys
|
||||
image_masks_dict = {
|
||||
"base_0_rgb": img_masks[0],
|
||||
"left_wrist_0_rgb": img_masks[1],
|
||||
"right_wrist_0_rgb": img_masks[2],
|
||||
}
|
||||
|
||||
return PI05Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
|
||||
def test_pi05_original_vs_lerobot():
|
||||
"""Test PI05 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05(
|
||||
from_pretrained=True
|
||||
) # Load pretrained LeRobot model
|
||||
original_pi0 = instantiate_original_pi05(
|
||||
from_pretrained=True
|
||||
) # Load pretrained OpenPI model from HuggingFace Hub
|
||||
|
||||
print("Creating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
batch_lerobot = deepcopy(batch)
|
||||
|
||||
# Test each model with its own preprocessing (more realistic end-to-end test)
|
||||
print("\nTest each model with its own preprocessing")
|
||||
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||
|
||||
print(f"Task prompt: '{batch['task'][0]}'")
|
||||
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
|
||||
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
|
||||
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
|
||||
|
||||
print("Testing OpenPI with own preprocessing...")
|
||||
original_pi0.eval()
|
||||
torch.manual_seed(42) # Set seed for reproducibility
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
|
||||
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
|
||||
)
|
||||
openpi_actions_unit = openpi_actions[:, 0, :]
|
||||
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
|
||||
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
|
||||
|
||||
print("Testing LeRobot with own preprocessing...")
|
||||
lerobot_pi05.eval()
|
||||
torch.manual_seed(42) # Set the same seed
|
||||
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
original_pi05.eval()
|
||||
with torch.no_grad():
|
||||
lerobot_actions_own = lerobot_pi05.predict_action_chunk(
|
||||
batch_lerobot_processed
|
||||
) # batch_size, n_action_steps, action_dim
|
||||
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
|
||||
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
|
||||
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
|
||||
lerobot_actions = lerobot_pi05.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
|
||||
openpi_actions = original_pi05.sample_actions(
|
||||
device=DEVICE,
|
||||
observation=openpi_observation,
|
||||
noise=noise,
|
||||
num_steps=10,
|
||||
)
|
||||
|
||||
print("\nComparing end-to-end implementations:")
|
||||
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
|
||||
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
||||
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
|
||||
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
|
||||
|
||||
def test_pi05_forward_matches_openpi():
|
||||
assert_forward_matches()
|
||||
|
||||
|
||||
def test_pi05_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi()
|
||||
|
||||
|
||||
def test_pi05_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(gradient_checkpointing=True)
|
||||
|
||||
|
||||
def test_pi05_compile_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True)
|
||||
|
||||
|
||||
def test_pi05_compile_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi(compile_model=True)
|
||||
|
||||
|
||||
def test_pi05_compile_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True, gradient_checkpointing=True)
|
||||
|
||||
99
tests/policies/pi0_pi05/test_pi0_compile.py
Normal file
99
tests/policies/pi0_pi05/test_pi0_compile.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi0 import PI0Config # noqa: E402
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
|
||||
assert_cache_stability,
|
||||
assert_compiled_output_matches_eager,
|
||||
assert_explain_has_no_graph_breaks,
|
||||
benchmark_runtime,
|
||||
make_compile_config,
|
||||
reset_compile_state,
|
||||
)
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
|
||||
def _make_model(*, compile_model):
|
||||
return PI0Pytorch(make_compile_config(PI0Config, compile_model=compile_model)).cuda().eval()
|
||||
|
||||
|
||||
def _make_dummy_inputs(config):
|
||||
device = torch.device("cuda")
|
||||
common = {
|
||||
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
|
||||
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
|
||||
"lang_tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
|
||||
"lang_masks": torch.ones(1, 5, dtype=torch.bool, device=device),
|
||||
"state": torch.randn(1, config.max_state_dim, device=device),
|
||||
}
|
||||
forward_kwargs = {
|
||||
**common,
|
||||
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"time": torch.rand(1, device=device),
|
||||
}
|
||||
sample_kwargs = {
|
||||
**common,
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"num_steps": config.num_inference_steps,
|
||||
}
|
||||
return forward_kwargs, sample_kwargs
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_torch_compile_forward_and_sample_actions():
|
||||
if not hasattr(torch, "compile"):
|
||||
pytest.skip("torch.compile is not available")
|
||||
if not torch._dynamo.is_dynamo_supported():
|
||||
pytest.skip("torch._dynamo is not supported on this platform")
|
||||
|
||||
torch.manual_seed(0)
|
||||
eager_model = _make_model(compile_model=False)
|
||||
torch.manual_seed(0)
|
||||
compiled_model = _make_model(compile_model=True)
|
||||
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
|
||||
|
||||
try:
|
||||
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
|
||||
|
||||
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi0.forward")
|
||||
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi0.sample_actions")
|
||||
|
||||
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi0.forward")
|
||||
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions")
|
||||
|
||||
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi0.forward")
|
||||
benchmark_runtime(
|
||||
eager_model.sample_actions, compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions"
|
||||
)
|
||||
finally:
|
||||
reset_compile_state()
|
||||
del eager_model
|
||||
del compiled_model
|
||||
torch.cuda.empty_cache()
|
||||
@@ -14,51 +14,56 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation"""
|
||||
"""Compare LeRobot PI0 against the vendored OpenPI PyTorch reference."""
|
||||
|
||||
import gc
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip if openpi or transformers is not available
|
||||
pytest.importorskip("openpi")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
from lerobot.configs import PreTrainedConfig # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
deterministic_openpi_forward_preprocess,
|
||||
fix_reference_state_dict,
|
||||
fixed_flow_sampling,
|
||||
load_openpi_reference_state_dict,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
|
||||
from lerobot.types import PolicyAction # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05)
|
||||
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||
DUMMY_MAX_TOKEN_LEN = 48
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
COMPILE_MODE = "default"
|
||||
FORWARD_RTOL = 1e-4
|
||||
FORWARD_ATOL = 1e-4
|
||||
SAMPLE_RTOL = 1e-2
|
||||
SAMPLE_ATOL = 5e-3
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
"action": {
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
@@ -87,6 +92,15 @@ DUMMY_DATASET_STATS = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_cuda_after_test():
|
||||
yield
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
class PI0BaseOriginalConfig:
|
||||
action_dim: int = DUMMY_ACTION_DIM
|
||||
action_horizon: int = DUMMY_ACTION_HORIZON
|
||||
@@ -95,333 +109,156 @@ class PI0BaseOriginalConfig:
|
||||
precision: str = "float32"
|
||||
pi05: bool = False
|
||||
dtype: str = "float32"
|
||||
pytorch_compile_mode: str | None = None
|
||||
|
||||
|
||||
def instantiate_lerobot_pi0(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI0Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI0Policy.from_pretrained(pretrained_name_or_path="lerobot/pi0_base", strict=True)
|
||||
else:
|
||||
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI0Policy(config)
|
||||
def instantiate_lerobot_pi0(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
config = PreTrainedConfig.from_pretrained("lerobot/pi0_base")
|
||||
config.device = str(DEVICE)
|
||||
config.dtype = "float32"
|
||||
config.compile_model = compile_model
|
||||
config.compile_mode = COMPILE_MODE
|
||||
config.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
policy = PI0Policy.from_pretrained("lerobot/pi0_base", config=config, strict=True)
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
policy.config.device = str(DEVICE)
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
return policy, preprocessor
|
||||
|
||||
|
||||
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None):
|
||||
config = PI0BaseOriginalConfig()
|
||||
policy = PI0Pytorch(config)
|
||||
|
||||
if from_pretrained:
|
||||
try:
|
||||
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi0_base)...")
|
||||
|
||||
# Download the model from HuggingFace Hub
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the entire repository
|
||||
if model_path and os.path.exists(model_path):
|
||||
cache_dir = model_path
|
||||
print(f"Using cached model from: {cache_dir}")
|
||||
else:
|
||||
cache_dir = snapshot_download(repo_id="lerobot/pi0_base", repo_type="model")
|
||||
print(f"Downloaded model to: {cache_dir}")
|
||||
|
||||
# Try to load safetensors format first
|
||||
model_file = os.path.join(cache_dir, "model.safetensors")
|
||||
if os.path.exists(model_file):
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
print(f"Loaded {len(state_dict)} parameters from safetensors")
|
||||
else:
|
||||
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
|
||||
|
||||
# Load the state dict into the model
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys: {len(missing_keys)}")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in missing_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"Unexpected keys: {len(unexpected_keys)}")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in unexpected_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("All pretrained weights loaded successfully!")
|
||||
else:
|
||||
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load pretrained weights: {e}")
|
||||
print(" Using randomly initialized weights...")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
policy.to(DEVICE)
|
||||
def instantiate_original_pi0():
|
||||
policy = PI0Pytorch(PI0BaseOriginalConfig()).to(DEVICE)
|
||||
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi0_base"))
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
assert missing_keys == []
|
||||
assert unexpected_keys == []
|
||||
return policy
|
||||
|
||||
|
||||
def create_dummy_data():
|
||||
batch_size = 2 # Reduce batch size for testing
|
||||
device = DEVICE
|
||||
|
||||
# Use the exact same prompt for both implementations
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
"observation.images.left_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
"observation.images.right_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
|
||||
"""Extract the exact same processed inputs that LeRobot uses internally."""
|
||||
# Get the tokenized language from LeRobot's internal method
|
||||
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
|
||||
|
||||
# Get the preprocessed images from LeRobot's internal method
|
||||
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for original implementation
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
|
||||
def prepare_parity_inputs(lerobot_pi0, lerobot_preprocessor):
|
||||
torch.manual_seed(0)
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
)
|
||||
openpi_actions = openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
)
|
||||
assert_processor_inputs_match_lerobot(
|
||||
lerobot_pi0,
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=True,
|
||||
)
|
||||
batch_size = raw_batch[OBS_STATE].shape[0]
|
||||
noise = torch.randn(
|
||||
batch_size,
|
||||
DUMMY_ACTION_HORIZON,
|
||||
DUMMY_ACTION_DIM,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE,
|
||||
)
|
||||
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
|
||||
return lerobot_batch, openpi_observation, openpi_actions, noise, time
|
||||
|
||||
|
||||
class PI0Observation:
|
||||
"""Observation class that matches the original OpenPI format."""
|
||||
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(
|
||||
compile_model=compile_model,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
original_pi0 = instantiate_original_pi0()
|
||||
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
|
||||
lerobot_pi0,
|
||||
lerobot_preprocessor,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state,
|
||||
images,
|
||||
image_masks,
|
||||
tokenized_prompt,
|
||||
tokenized_prompt_mask,
|
||||
token_ar_mask,
|
||||
token_loss_mask,
|
||||
):
|
||||
self.state = state
|
||||
self.images = images
|
||||
self.image_masks = image_masks
|
||||
self.tokenized_prompt = tokenized_prompt
|
||||
self.tokenized_prompt_mask = tokenized_prompt_mask
|
||||
self.token_ar_mask = token_ar_mask
|
||||
self.token_loss_mask = token_loss_mask
|
||||
|
||||
|
||||
def create_original_observation_with_openpi_preprocessing(batch):
|
||||
"""Create observation object for OpenPI using OpenPI's own preprocessing."""
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
device = batch["observation.state"].device
|
||||
|
||||
# Create tokenizer for OpenPI (same as LeRobot uses)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
# Get task description
|
||||
if "task" in batch:
|
||||
tasks = batch["task"]
|
||||
if isinstance(tasks, str):
|
||||
# Single string: add newline if not present, then convert to list
|
||||
if not tasks.endswith("\n"):
|
||||
tasks = f"{tasks}\n"
|
||||
tasks = [tasks]
|
||||
elif isinstance(tasks, list) and all(isinstance(t, str) for t in tasks):
|
||||
# List of strings: add newline to each if not present
|
||||
tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks]
|
||||
if len(tasks) == 1:
|
||||
# Expand to batch size
|
||||
tasks = tasks * batch_size
|
||||
if len(tasks) != batch_size:
|
||||
raise ValueError(f"Expected batch size {batch_size}, got {len(tasks)}")
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
if gradient_checkpointing:
|
||||
lerobot_pi0.train()
|
||||
else:
|
||||
# Default task if not provided
|
||||
tasks = ["Pick up the object\n"] * batch_size
|
||||
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
tokenized = tokenizer(
|
||||
tasks,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=DUMMY_MAX_TOKEN_LEN,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
lang_tokens = tokenized["input_ids"].to(device)
|
||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for OpenPI
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
|
||||
image_dict = {
|
||||
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
|
||||
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
}
|
||||
|
||||
# Create image masks (all ones for real images)
|
||||
image_masks_dict = {}
|
||||
for key in image_dict:
|
||||
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||
|
||||
# Create raw observation object (before preprocessing)
|
||||
raw_observation = PI0Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
# Now use OpenPI's preprocessing
|
||||
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
|
||||
|
||||
return processed_obs
|
||||
|
||||
|
||||
def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
|
||||
_batch_size = batch["observation.state"].shape[0]
|
||||
_device = batch["observation.state"].device
|
||||
|
||||
# Extract the exact same processed inputs that LeRobot uses
|
||||
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
|
||||
extract_lerobot_processed_inputs(lerobot_pi0, batch)
|
||||
)
|
||||
|
||||
# Convert images list to dict with original OpenPI keys
|
||||
image_dict = {
|
||||
"base_0_rgb": images[0],
|
||||
"left_wrist_0_rgb": images[1],
|
||||
"right_wrist_0_rgb": images[2],
|
||||
}
|
||||
|
||||
# Convert image masks list to dict with original OpenPI keys
|
||||
image_masks_dict = {
|
||||
"base_0_rgb": img_masks[0],
|
||||
"left_wrist_0_rgb": img_masks[1],
|
||||
"right_wrist_0_rgb": img_masks[2],
|
||||
}
|
||||
|
||||
return PI0Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
|
||||
def test_pi0_original_vs_lerobot():
|
||||
"""Test PI0 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
lerobot_pi0, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi0(
|
||||
from_pretrained=True
|
||||
) # Load pretrained LeRobot model
|
||||
original_pi0 = instantiate_original_pi0(
|
||||
from_pretrained=True
|
||||
) # Load pretrained OpenPI model from HuggingFace Hub
|
||||
|
||||
print("Creating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
batch_lerobot = deepcopy(batch)
|
||||
|
||||
# Test each model with its own preprocessing (more realistic end-to-end test)
|
||||
print("\nTest each model with its own preprocessing")
|
||||
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||
|
||||
print(f"Task prompt: '{batch['task'][0]}'")
|
||||
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
|
||||
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
|
||||
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
|
||||
|
||||
print("Testing OpenPI with own preprocessing...")
|
||||
lerobot_pi0.eval()
|
||||
original_pi0.eval()
|
||||
torch.manual_seed(42) # Set seed for reproducibility
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
|
||||
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
|
||||
)
|
||||
openpi_actions_unit = openpi_actions[:, 0, :]
|
||||
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
|
||||
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
|
||||
with fixed_flow_sampling(lerobot_pi0.model, noise=noise, time=time):
|
||||
lerobot_loss, _ = lerobot_pi0(lerobot_batch, reduction="none")
|
||||
with deterministic_openpi_forward_preprocess(original_pi0):
|
||||
openpi_losses = original_pi0(openpi_observation, openpi_actions, noise=noise, time=time)
|
||||
openpi_loss = openpi_losses.mean(dim=(1, 2))
|
||||
|
||||
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
|
||||
|
||||
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
|
||||
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(compile_model=compile_model)
|
||||
original_pi0 = instantiate_original_pi0()
|
||||
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
|
||||
lerobot_pi0,
|
||||
lerobot_preprocessor,
|
||||
)
|
||||
|
||||
print("Testing LeRobot with own preprocessing...")
|
||||
lerobot_pi0.eval()
|
||||
torch.manual_seed(42) # Set the same seed
|
||||
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
original_pi0.eval()
|
||||
with torch.no_grad():
|
||||
lerobot_actions_own = lerobot_pi0.predict_action_chunk(
|
||||
batch_lerobot_processed
|
||||
) # batch_size, n_action_steps, action_dim
|
||||
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
|
||||
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
|
||||
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
|
||||
lerobot_actions = lerobot_pi0.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE,
|
||||
observation=openpi_observation,
|
||||
noise=noise,
|
||||
num_steps=10,
|
||||
)
|
||||
|
||||
print("\nComparing end-to-end implementations:")
|
||||
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
|
||||
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
||||
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
|
||||
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
|
||||
|
||||
def test_pi0_forward_matches_openpi():
|
||||
assert_forward_matches()
|
||||
|
||||
|
||||
def test_pi0_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi()
|
||||
|
||||
|
||||
def test_pi0_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(gradient_checkpointing=True)
|
||||
|
||||
|
||||
def test_pi0_compile_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True)
|
||||
|
||||
|
||||
def test_pi0_compile_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi(compile_model=True)
|
||||
|
||||
|
||||
def test_pi0_compile_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True, gradient_checkpointing=True)
|
||||
|
||||
1
tests/policies/pi0_pi05/utils/__init__.py
Normal file
1
tests/policies/pi0_pi05/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Utilities shared by PI0/PI05 policy tests."""
|
||||
291
tests/policies/pi0_pi05/utils/openpi_parity.py
Normal file
291
tests/policies/pi0_pi05/utils/openpi_parity.py
Normal file
@@ -0,0 +1,291 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as openpi_preprocessing
|
||||
|
||||
IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
|
||||
TOKENIZER_NAME = "google/paligemma-3b-pt-224"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenPIObservation:
|
||||
state: torch.Tensor
|
||||
images: dict[str, torch.Tensor]
|
||||
image_masks: dict[str, torch.Tensor]
|
||||
tokenized_prompt: torch.Tensor
|
||||
tokenized_prompt_mask: torch.Tensor
|
||||
token_ar_mask: torch.Tensor
|
||||
token_loss_mask: torch.Tensor
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def paligemma_tokenizer():
|
||||
return AutoTokenizer.from_pretrained(TOKENIZER_NAME)
|
||||
|
||||
|
||||
def clone_batch(batch: dict) -> dict:
|
||||
return {
|
||||
key: value.clone() if isinstance(value, torch.Tensor) else list(value) for key, value in batch.items()
|
||||
}
|
||||
|
||||
|
||||
def pad_last_dim(tensor: torch.Tensor, target_dim: int) -> torch.Tensor:
|
||||
if tensor.shape[-1] > target_dim:
|
||||
raise ValueError(f"Cannot pad last dimension {tensor.shape[-1]} down to {target_dim}")
|
||||
return F.pad(tensor, (0, target_dim - tensor.shape[-1]))
|
||||
|
||||
|
||||
def mean_std_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
mean = stats["mean"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
std = stats["std"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
return (tensor - mean) / (std + 1e-8)
|
||||
|
||||
|
||||
def quantile_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
q01 = stats["q01"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
q99 = stats["q99"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
denom = torch.where(q99 == q01, torch.full_like(q99, 1e-8), q99 - q01)
|
||||
return 2.0 * (tensor - q01) / denom - 1.0
|
||||
|
||||
|
||||
def openpi_model_state_from_raw(
|
||||
batch: dict[str, torch.Tensor],
|
||||
*,
|
||||
action_dim: int,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
||||
pi05: bool,
|
||||
) -> torch.Tensor:
|
||||
state = batch[OBS_STATE].to(dtype=torch.float32)
|
||||
if pi05:
|
||||
state = quantile_normalize(state, dataset_stats[OBS_STATE])
|
||||
else:
|
||||
state = mean_std_normalize(state, dataset_stats[OBS_STATE])
|
||||
return pad_last_dim(state, action_dim)
|
||||
|
||||
|
||||
def openpi_model_actions_from_raw(
|
||||
batch: dict[str, torch.Tensor],
|
||||
*,
|
||||
action_dim: int,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
||||
pi05: bool,
|
||||
) -> torch.Tensor:
|
||||
actions = batch[ACTION].to(dtype=torch.float32)
|
||||
if pi05:
|
||||
actions = quantile_normalize(actions, dataset_stats[ACTION])
|
||||
else:
|
||||
actions = mean_std_normalize(actions, dataset_stats[ACTION])
|
||||
return pad_last_dim(actions, action_dim)
|
||||
|
||||
|
||||
def _tasks_from_raw(batch: dict, batch_size: int) -> list[str]:
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
raise ValueError("The parity batch must include a task prompt.")
|
||||
if isinstance(tasks, str):
|
||||
return [tasks] * batch_size
|
||||
if len(tasks) == 1:
|
||||
return [tasks[0]] * batch_size
|
||||
if len(tasks) != batch_size:
|
||||
raise ValueError(f"Expected {batch_size} task prompts, got {len(tasks)}")
|
||||
return list(tasks)
|
||||
|
||||
|
||||
def _format_pi0_prompts(tasks: list[str]) -> list[str]:
|
||||
return [f"{task.strip().replace('_', ' ').replace(chr(10), ' ')}\n" for task in tasks]
|
||||
|
||||
|
||||
def _format_pi05_prompts(tasks: list[str], normalized_state: torch.Tensor) -> list[str]:
|
||||
state_np = normalized_state.detach().cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
prompts = []
|
||||
for task, state in zip(tasks, discretized_states, strict=True):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, state))
|
||||
prompts.append(f"Task: {cleaned_text}, State: {state_str};\nAction: ")
|
||||
return prompts
|
||||
|
||||
|
||||
def _tokenize_prompts(prompts: list[str], *, max_token_len: int, device: torch.device | str):
|
||||
tokenized = paligemma_tokenizer()(
|
||||
prompts,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=max_token_len,
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = tokenized["input_ids"].to(device)
|
||||
masks = tokenized["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
return tokens, masks
|
||||
|
||||
|
||||
def make_openpi_observation_from_raw(
|
||||
batch: dict[str, torch.Tensor],
|
||||
*,
|
||||
action_dim: int,
|
||||
max_token_len: int,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
||||
pi05: bool,
|
||||
) -> OpenPIObservation:
|
||||
batch_size = batch[OBS_STATE].shape[0]
|
||||
device = batch[OBS_STATE].device
|
||||
state = openpi_model_state_from_raw(
|
||||
batch,
|
||||
action_dim=action_dim,
|
||||
dataset_stats=dataset_stats,
|
||||
pi05=pi05,
|
||||
)
|
||||
|
||||
tasks = _tasks_from_raw(batch, batch_size)
|
||||
prompts = _format_pi05_prompts(tasks, state) if pi05 else _format_pi0_prompts(tasks)
|
||||
tokens, masks = _tokenize_prompts(prompts, max_token_len=max_token_len, device=device)
|
||||
|
||||
images = {
|
||||
key: batch[f"observation.images.{key}"].to(device=device, dtype=torch.float32) * 2.0 - 1.0
|
||||
for key in IMAGE_KEYS
|
||||
}
|
||||
image_masks = {key: torch.ones(batch_size, dtype=torch.bool, device=device) for key in IMAGE_KEYS}
|
||||
|
||||
return OpenPIObservation(
|
||||
state=state,
|
||||
images=images,
|
||||
image_masks=image_masks,
|
||||
tokenized_prompt=tokens,
|
||||
tokenized_prompt_mask=masks,
|
||||
token_ar_mask=torch.zeros_like(tokens, dtype=torch.int32),
|
||||
token_loss_mask=torch.ones_like(masks, dtype=torch.bool),
|
||||
)
|
||||
|
||||
|
||||
def assert_processor_inputs_match_lerobot(
|
||||
lerobot_policy,
|
||||
lerobot_batch: dict[str, torch.Tensor],
|
||||
openpi_observation: OpenPIObservation,
|
||||
*,
|
||||
compare_state: bool,
|
||||
):
|
||||
openpi_processed = openpi_preprocessing.preprocess_observation_pytorch(openpi_observation, train=False)
|
||||
lerobot_images, lerobot_image_masks = lerobot_policy._preprocess_images(lerobot_batch)
|
||||
|
||||
# Token IDs, token masks, images, image masks, and PI0 state are intentionally built from the same
|
||||
# raw batch through independent LeRobot/OpenPI-style processor logic. They must be bitwise equal.
|
||||
torch.testing.assert_close(
|
||||
openpi_observation.tokenized_prompt, lerobot_batch[OBS_LANGUAGE_TOKENS], rtol=0, atol=0
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
openpi_observation.tokenized_prompt_mask,
|
||||
lerobot_batch[OBS_LANGUAGE_ATTENTION_MASK],
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
|
||||
for openpi_image, lerobot_image in zip(openpi_processed.images.values(), lerobot_images, strict=True):
|
||||
torch.testing.assert_close(openpi_image, lerobot_image, rtol=0, atol=0)
|
||||
|
||||
for openpi_mask, lerobot_mask in zip(
|
||||
openpi_processed.image_masks.values(), lerobot_image_masks, strict=True
|
||||
):
|
||||
torch.testing.assert_close(openpi_mask, lerobot_mask, rtol=0, atol=0)
|
||||
|
||||
if compare_state:
|
||||
torch.testing.assert_close(
|
||||
openpi_processed.state, lerobot_policy.prepare_state(lerobot_batch), rtol=0, atol=0
|
||||
)
|
||||
|
||||
|
||||
def load_openpi_reference_state_dict(repo_id: str) -> dict[str, torch.Tensor]:
|
||||
cache_dir = Path(snapshot_download(repo_id=repo_id, repo_type="model"))
|
||||
return safetensors.torch.load_file(cache_dir / "model.safetensors")
|
||||
|
||||
|
||||
def fix_reference_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
fixed_state_dict = dict(state_dict)
|
||||
lm_head_key = "paligemma_with_expert.paligemma.lm_head.weight"
|
||||
embed_tokens_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
if lm_head_key in fixed_state_dict and embed_tokens_key not in fixed_state_dict:
|
||||
fixed_state_dict[embed_tokens_key] = fixed_state_dict[lm_head_key].clone()
|
||||
return fixed_state_dict
|
||||
|
||||
|
||||
@contextmanager
|
||||
def fixed_flow_sampling(model, *, noise: torch.Tensor, time: torch.Tensor) -> Iterator[None]:
|
||||
original_sample_noise = model.sample_noise
|
||||
original_sample_time = model.sample_time
|
||||
|
||||
def sample_noise(shape, device):
|
||||
if tuple(shape) != tuple(noise.shape):
|
||||
raise ValueError(f"Expected noise shape {tuple(noise.shape)}, got {tuple(shape)}")
|
||||
return noise.to(device=device)
|
||||
|
||||
def sample_time(batch_size, device):
|
||||
if batch_size != time.shape[0]:
|
||||
raise ValueError(f"Expected time batch size {time.shape[0]}, got {batch_size}")
|
||||
return time.to(device=device)
|
||||
|
||||
model.sample_noise = sample_noise
|
||||
model.sample_time = sample_time
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
model.sample_noise = original_sample_noise
|
||||
model.sample_time = original_sample_time
|
||||
|
||||
|
||||
@contextmanager
|
||||
def deterministic_openpi_forward_preprocess(openpi_policy) -> Iterator[None]:
|
||||
"""Disable OpenPI's training-time image augmentation only inside a parity forward block.
|
||||
|
||||
OpenPI's `forward()` calls `_preprocess_observation(..., train=True)`, which can apply stochastic
|
||||
image augmentation. LeRobot's policy forward path does not apply that augmentation, so parity would
|
||||
otherwise compare two different image tensors rather than two model implementations. The context manager
|
||||
keeps the public `openpi_policy.forward(observation, ...)` call while making preprocessing deterministic.
|
||||
|
||||
`yield` marks the body of the caller's `with` block. The `try/finally` restores the original method even
|
||||
if the assertion inside the block fails, so the temporary monkeypatch cannot leak into later tests.
|
||||
"""
|
||||
|
||||
original_preprocess_observation = openpi_policy._preprocess_observation
|
||||
|
||||
def preprocess_observation(observation, *, train=True):
|
||||
return original_preprocess_observation(observation, train=False)
|
||||
|
||||
openpi_policy._preprocess_observation = preprocess_observation
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
openpi_policy._preprocess_observation = original_preprocess_observation
|
||||
207
tests/policies/pi0_pi05/utils/torch_compile.py
Normal file
207
tests/policies/pi0_pi05/utils/torch_compile.py
Normal file
@@ -0,0 +1,207 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters, guard_failures
|
||||
from torch.profiler import ProfilerActivity
|
||||
|
||||
FORWARD_RTOL = 1e-5
|
||||
FORWARD_ATOL = 5e-2
|
||||
SAMPLE_RTOL = 1e-5
|
||||
SAMPLE_ATOL = 1e-2
|
||||
COMPILE_MODE = "max-autotune"
|
||||
STEADY_STATE_WARMUPS = 3
|
||||
STEADY_STATE_REPEATS = 3
|
||||
|
||||
|
||||
def make_compile_config(config_cls, *, compile_model):
|
||||
return config_cls(device="cuda", compile_model=compile_model, compile_mode=COMPILE_MODE)
|
||||
|
||||
|
||||
def counter_total(name):
|
||||
return sum(counters.get(name, {}).values())
|
||||
|
||||
|
||||
def compile_snapshot():
|
||||
return {
|
||||
"graph_breaks": counter_total("graph_break"),
|
||||
"recompiles": counter_total("recompiles"),
|
||||
"recompile_limits": counter_total("recompile_limit"),
|
||||
"unique_graphs": counters["stats"].get("unique_graphs", 0),
|
||||
}
|
||||
|
||||
|
||||
def reset_compile_state():
|
||||
torch._dynamo.reset()
|
||||
counters.clear()
|
||||
guard_failures.clear()
|
||||
|
||||
|
||||
def clone_cuda_graph_output(output):
|
||||
if torch.is_tensor(output):
|
||||
return output.clone()
|
||||
if isinstance(output, tuple):
|
||||
return tuple(clone_cuda_graph_output(item) for item in output)
|
||||
if isinstance(output, list):
|
||||
return [clone_cuda_graph_output(item) for item in output]
|
||||
if isinstance(output, dict):
|
||||
return {key: clone_cuda_graph_output(value) for key, value in output.items()}
|
||||
return output
|
||||
|
||||
|
||||
def run_model_step(fn: Callable, kwargs: dict):
|
||||
if hasattr(torch.compiler, "cudagraph_mark_step_begin"):
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
return fn(**kwargs)
|
||||
|
||||
|
||||
def assert_explain_has_no_graph_breaks(fn: Callable, kwargs: dict, label: str):
|
||||
reset_compile_state()
|
||||
explanation = torch._dynamo.explain(fn)(**kwargs)
|
||||
|
||||
assert explanation.graph_count > 0, f"{label} was not captured by Dynamo"
|
||||
assert explanation.graph_break_count == 0, (
|
||||
f"{label} has {explanation.graph_break_count} graph break(s): {explanation.break_reasons}"
|
||||
)
|
||||
assert not explanation.break_reasons, f"{label} graph break reasons: {explanation.break_reasons}"
|
||||
|
||||
print(
|
||||
f"{label} capture: graphs={explanation.graph_count}, "
|
||||
f"graph_breaks={explanation.graph_break_count}, ops={explanation.op_count}, "
|
||||
f"guards={len(explanation.out_guards or [])}"
|
||||
)
|
||||
return explanation
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs):
|
||||
eager_forward = eager_model.forward(**forward_kwargs)
|
||||
compiled_forward = compiled_model.forward(**forward_kwargs)
|
||||
torch.testing.assert_close(compiled_forward, eager_forward, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
|
||||
eager_actions = eager_model.sample_actions(**sample_kwargs)
|
||||
compiled_actions = compiled_model.sample_actions(**sample_kwargs)
|
||||
torch.testing.assert_close(compiled_actions, eager_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def assert_cache_stability(fn: Callable, kwargs: dict, label: str):
|
||||
reset_compile_state()
|
||||
|
||||
first_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
||||
first_snapshot = compile_snapshot()
|
||||
second_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
||||
second_snapshot = compile_snapshot()
|
||||
third_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
||||
third_snapshot = compile_snapshot()
|
||||
|
||||
torch.testing.assert_close(second_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
torch.testing.assert_close(third_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
assert first_snapshot["unique_graphs"] > 0, f"{label} did not compile any graph"
|
||||
assert third_snapshot["graph_breaks"] == 0, f"{label} graph breaks: {third_snapshot}"
|
||||
assert third_snapshot["recompiles"] == 0, f"{label} recompiled: {third_snapshot}"
|
||||
assert third_snapshot["recompile_limits"] == 0, f"{label} hit recompile limit: {third_snapshot}"
|
||||
assert second_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
|
||||
f"{label} compiled new graph on second call: first={first_snapshot}, second={second_snapshot}"
|
||||
)
|
||||
assert third_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
|
||||
f"{label} compiled new graph on third call: first={first_snapshot}, third={third_snapshot}"
|
||||
)
|
||||
assert not guard_failures, f"{label} guard failures: {dict(guard_failures)}"
|
||||
|
||||
print(f"{label} cache: first={first_snapshot}, third={third_snapshot}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def benchmark_runtime(eager_fn: Callable, compiled_fn: Callable, kwargs: dict, label: str):
|
||||
run_warmups(eager_fn, kwargs)
|
||||
run_warmups(compiled_fn, kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
eager_metrics = profile_callable(eager_fn, kwargs)
|
||||
compiled_metrics = profile_callable(compiled_fn, kwargs)
|
||||
speedup = eager_metrics["cuda_event_ms"] / compiled_metrics["cuda_event_ms"]
|
||||
|
||||
print(
|
||||
f"{label} runtime: eager_cuda={eager_metrics['cuda_event_ms']:.3f} ms, "
|
||||
f"compiled_cuda={compiled_metrics['cuda_event_ms']:.3f} ms, speedup={speedup:.3f}x, "
|
||||
f"host_wall_ms eager/compiled={eager_metrics['host_wall_ms']:.3f}/"
|
||||
f"{compiled_metrics['host_wall_ms']:.3f}, "
|
||||
f"cpu_self_time_ms eager/compiled={eager_metrics['cpu_self_time_ms']:.3f}/"
|
||||
f"{compiled_metrics['cpu_self_time_ms']:.3f}, "
|
||||
f"cuda_launches eager/compiled={eager_metrics['cuda_launch_count']}/"
|
||||
f"{compiled_metrics['cuda_launch_count']}, "
|
||||
f"profiler_events eager/compiled={eager_metrics['profiler_event_count']}/"
|
||||
f"{compiled_metrics['profiler_event_count']}, "
|
||||
f"peak_mem_mib eager/compiled={eager_metrics['peak_mem_mib']:.1f}/"
|
||||
f"{compiled_metrics['peak_mem_mib']:.1f}"
|
||||
)
|
||||
|
||||
assert eager_metrics["cuda_event_ms"] > 0
|
||||
assert compiled_metrics["cuda_event_ms"] > 0
|
||||
assert eager_metrics["profiler_event_count"] > 0
|
||||
assert compiled_metrics["profiler_event_count"] > 0
|
||||
return eager_metrics, compiled_metrics
|
||||
|
||||
|
||||
def run_warmups(fn: Callable, kwargs: dict):
|
||||
for _ in range(STEADY_STATE_WARMUPS):
|
||||
run_model_step(fn, kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def profile_callable(fn: Callable, kwargs: dict):
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
host_start = time.perf_counter()
|
||||
start_event.record()
|
||||
for _ in range(STEADY_STATE_REPEATS):
|
||||
run_model_step(fn, kwargs)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
cuda_event_ms = start_event.elapsed_time(end_event) / STEADY_STATE_REPEATS
|
||||
host_wall_ms = (time.perf_counter() - host_start) * 1000 / STEADY_STATE_REPEATS
|
||||
peak_mem_mib = torch.cuda.max_memory_allocated() / 1024**2
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[ProfilerActivity.CPU],
|
||||
) as profiler:
|
||||
run_model_step(fn, kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
key_averages = profiler.key_averages()
|
||||
cpu_self_time_ms = sum(event.self_cpu_time_total for event in key_averages) / 1000
|
||||
cuda_launch_count = sum(
|
||||
event.count
|
||||
for event in key_averages
|
||||
if event.key in {"cudaLaunchKernel", "cudaGraphLaunch", "cudaLaunchKernelExC"}
|
||||
)
|
||||
profiler_event_count = sum(event.count for event in key_averages)
|
||||
|
||||
return {
|
||||
"cuda_event_ms": cuda_event_ms,
|
||||
"host_wall_ms": host_wall_ms,
|
||||
"cpu_self_time_ms": cpu_self_time_ms,
|
||||
"cuda_launch_count": cuda_launch_count,
|
||||
"profiler_event_count": profiler_event_count,
|
||||
"peak_mem_mib": peak_mem_mib,
|
||||
}
|
||||
155
tests/processor/test_pi05_processor.py
Normal file
155
tests/processor/test_pi05_processor.py
Normal file
@@ -0,0 +1,155 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compare the PI0.5 processor pipeline against the vendored OpenPI reference processors."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
|
||||
from lerobot.policies.pi05 import PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
IMAGE_KEYS,
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
|
||||
)
|
||||
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 200
|
||||
DEVICE = torch.device("cpu")
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"q99": torch.ones(DUMMY_ACTION_DIM),
|
||||
},
|
||||
"images": {
|
||||
key: {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
}
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PI05PolicyInputAdapter(torch.nn.Module):
|
||||
"""Minimal adapter exposing PI0.5 policy image preparation without loading model weights."""
|
||||
|
||||
_preprocess_images = PI05Policy._preprocess_images
|
||||
|
||||
def __init__(self, config: PI05Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
|
||||
|
||||
|
||||
def create_pi05_config() -> PI05Config:
|
||||
config = PI05Config(device=str(DEVICE))
|
||||
config.max_state_dim = DUMMY_STATE_DIM
|
||||
config.max_action_dim = DUMMY_ACTION_DIM
|
||||
config.chunk_size = DUMMY_ACTION_HORIZON
|
||||
config.n_action_steps = DUMMY_ACTION_HORIZON
|
||||
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
|
||||
**{
|
||||
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def create_dummy_data() -> dict:
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
**{
|
||||
f"observation.images.{key}": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
)
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
|
||||
|
||||
def test_pi05_processor_inputs_match_openpi_reference():
|
||||
torch.manual_seed(0)
|
||||
config = create_pi05_config()
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
)
|
||||
|
||||
assert_processor_inputs_match_lerobot(
|
||||
PI05PolicyInputAdapter(config),
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=False,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
lerobot_batch[ACTION],
|
||||
openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
),
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
156
tests/processor/test_pi0_processor.py
Normal file
156
tests/processor/test_pi0_processor.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compare the PI0 processor pipeline against the vendored OpenPI reference processors."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
IMAGE_KEYS,
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
|
||||
)
|
||||
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 48
|
||||
DEVICE = torch.device("cpu")
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"q99": torch.ones(DUMMY_ACTION_DIM),
|
||||
},
|
||||
"images": {
|
||||
key: {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
}
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PI0PolicyInputAdapter(torch.nn.Module):
|
||||
"""Minimal adapter exposing PI0 policy input-preparation helpers without loading model weights."""
|
||||
|
||||
_preprocess_images = PI0Policy._preprocess_images
|
||||
prepare_state = PI0Policy.prepare_state
|
||||
|
||||
def __init__(self, config: PI0Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
|
||||
|
||||
|
||||
def create_pi0_config() -> PI0Config:
|
||||
config = PI0Config(device=str(DEVICE))
|
||||
config.max_state_dim = DUMMY_STATE_DIM
|
||||
config.max_action_dim = DUMMY_ACTION_DIM
|
||||
config.chunk_size = DUMMY_ACTION_HORIZON
|
||||
config.n_action_steps = DUMMY_ACTION_HORIZON
|
||||
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
|
||||
**{
|
||||
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def create_dummy_data() -> dict:
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
**{
|
||||
f"observation.images.{key}": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
)
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
|
||||
|
||||
def test_pi0_processor_inputs_match_openpi_reference():
|
||||
torch.manual_seed(0)
|
||||
config = create_pi0_config()
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
)
|
||||
|
||||
assert_processor_inputs_match_lerobot(
|
||||
PI0PolicyInputAdapter(config),
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=True,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
lerobot_batch[ACTION],
|
||||
openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
),
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
@@ -1,10 +1,14 @@
|
||||
"""Tests for policy.path support in YAML config files (issue #2957)."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from unittest.mock import patch
|
||||
|
||||
import yaml
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.parser import (
|
||||
_config_path_args,
|
||||
_config_yaml_overrides,
|
||||
@@ -16,7 +20,8 @@ from lerobot.configs.parser import (
|
||||
|
||||
|
||||
def test_extract_path_fields_from_yaml():
|
||||
"""Test that policy.path is extracted from a YAML config and removed."""
|
||||
"""Test that policy.path is extracted from a YAML config and the policy block
|
||||
is removed entirely (siblings are captured separately as cli_overrides)."""
|
||||
config = {
|
||||
"dataset": {"repo_id": "lerobot/pusht"},
|
||||
"policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False},
|
||||
@@ -26,26 +31,33 @@ def test_extract_path_fields_from_yaml():
|
||||
config_path = f.name
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
||||
|
||||
# Path should be extracted and stored
|
||||
assert _config_path_args["policy"] == "lerobot/smolvla_base"
|
||||
|
||||
# Cleaned config should not have the path field
|
||||
# Cleaned config should not have the policy block at all -- draccus must not
|
||||
# try to decode it as PreTrainedConfig; the actual config comes from
|
||||
# from_pretrained(path) with the captured overrides applied on top.
|
||||
with open(cleaned_path) as f:
|
||||
cleaned = yaml.safe_load(f)
|
||||
assert "path" not in cleaned["policy"]
|
||||
assert cleaned["policy"]["type"] == "smolvla"
|
||||
assert cleaned["policy"]["push_to_hub"] is False
|
||||
assert "policy" not in cleaned
|
||||
|
||||
# Original dataset should be untouched
|
||||
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
|
||||
|
||||
# Sibling overrides (excluding type/path) captured for from_pretrained.
|
||||
overrides = get_yaml_overrides("policy")
|
||||
assert any("push_to_hub=false" in o for o in overrides)
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
|
||||
def test_extract_path_fields_from_json():
|
||||
"""Test that policy.path is extracted from a JSON config."""
|
||||
"""Test that policy.path is extracted from a JSON config and the policy
|
||||
block is removed entirely."""
|
||||
config = {
|
||||
"policy": {"type": "act", "path": "some/local/path"},
|
||||
}
|
||||
@@ -54,15 +66,17 @@ def test_extract_path_fields_from_json():
|
||||
config_path = f.name
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
||||
|
||||
assert _config_path_args["policy"] == "some/local/path"
|
||||
|
||||
with open(cleaned_path) as f:
|
||||
cleaned = json.load(f)
|
||||
assert "path" not in cleaned["policy"]
|
||||
assert "policy" not in cleaned
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
|
||||
def test_extract_no_path_returns_original():
|
||||
@@ -216,3 +230,91 @@ def test_flatten_nested_with_bools():
|
||||
args = _flatten_to_cli_args(d)
|
||||
assert "--optimizer.use_warmup=true" in args
|
||||
assert "--optimizer.lr=0.01" in args
|
||||
|
||||
|
||||
def test_extract_removes_field_with_siblings_and_no_type():
|
||||
"""Regression: when policy.path has siblings but no type:, the entire policy
|
||||
block must still be removed from the cleaned config. Otherwise draccus tries
|
||||
to decode the leftover dict as PreTrainedConfig and crashes on the missing
|
||||
type discriminator.
|
||||
"""
|
||||
config = {
|
||||
"dataset": {"repo_id": "lerobot/pusht"},
|
||||
"policy": {
|
||||
"path": "lerobot/smolvla_base",
|
||||
"n_action_steps": 10,
|
||||
"dtype": "bfloat16",
|
||||
},
|
||||
}
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
yaml.dump(config, f)
|
||||
config_path = f.name
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
||||
|
||||
with open(cleaned_path) as f:
|
||||
cleaned = yaml.safe_load(f) or {}
|
||||
assert "policy" not in cleaned, "policy block should be fully removed when path is present"
|
||||
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
|
||||
assert _config_path_args["policy"] == "lerobot/smolvla_base"
|
||||
overrides = get_yaml_overrides("policy")
|
||||
assert any("n_action_steps=10" in o for o in overrides)
|
||||
assert any("dtype=bfloat16" in o for o in overrides)
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DummyNested:
|
||||
foo: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DummyConfig:
|
||||
nested: _DummyNested = field(default_factory=_DummyNested)
|
||||
other: str = "default"
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls):
|
||||
return ["nested"]
|
||||
|
||||
|
||||
def test_wrap_uses_cleaned_config_for_draccus_parse():
|
||||
"""Regression: wrap() updates config_path_cli to point at the cleaned temp
|
||||
file but must propagate that to the draccus.parse fallback branch. Without
|
||||
the fix, cli_args still contains --config_path=<original> and draccus reads
|
||||
the original YAML with `path:` still in it, crashing on the unknown field.
|
||||
"""
|
||||
config = {
|
||||
"nested": {"path": "some/checkpoint", "foo": 42},
|
||||
"other": "set-via-yaml",
|
||||
}
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
yaml.dump(config, f)
|
||||
config_path = f.name
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: _DummyConfig) -> _DummyConfig:
|
||||
captured["cfg"] = cfg
|
||||
return cfg
|
||||
|
||||
with patch.object(sys, "argv", ["prog", f"--config_path={config_path}"]):
|
||||
main()
|
||||
|
||||
assert captured["cfg"].other == "set-via-yaml"
|
||||
assert _config_path_args["nested"] == "some/checkpoint"
|
||||
# Cleaned config dropped `nested:` entirely; defaults stand for this wrapper
|
||||
# class (a real PreTrainedConfig would now load the checkpoint and apply
|
||||
# the captured yaml_overrides via from_pretrained()).
|
||||
assert captured["cfg"].nested.foo == 0
|
||||
|
||||
_config_path_args.clear()
|
||||
_config_yaml_overrides.clear()
|
||||
|
||||
22
uv.lock
generated
22
uv.lock
generated
@@ -3203,7 +3203,7 @@ requires-dist = [
|
||||
{ name = "pandas", marker = "extra == 'video-benchmark'", specifier = ">=2.2.2,<2.4.0" },
|
||||
{ name = "peft", marker = "extra == 'peft-dep'", specifier = ">=0.18.0,<1.0.0" },
|
||||
{ name = "pillow", specifier = ">=10.0.0,<13.0.0" },
|
||||
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.17" },
|
||||
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.16" },
|
||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" },
|
||||
{ name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" },
|
||||
{ name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" },
|
||||
@@ -4592,7 +4592,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "placo"
|
||||
version = "0.9.16"
|
||||
version = "0.9.15"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cmeel" },
|
||||
@@ -4602,16 +4602,16 @@ dependencies = [
|
||||
{ name = "pin" },
|
||||
{ name = "rhoban-cmeel-jsoncpp" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9e/0a/36c5b729d0d69075e7dfafd1b36c4df6fbb8c1ff1585e88d3c56d4c15010/placo-0.9.16.tar.gz", hash = "sha256:5314faaf6442e7ffe17347680d236af953951813bbfb1c09c4a27f7388d332e4", size = 136871, upload-time = "2025-11-07T14:24:58.811Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/40/c4/a33a0ee2ad798471a1c43a96109d28f358fd95c78a56f8cff57acb66d2bc/placo-0.9.15.tar.gz", hash = "sha256:df47f1154bae305c943bd20ba4f56d50ffc65625efc98679fefb11e8ff3c462c", size = 136856, upload-time = "2025-11-03T10:49:13.151Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/95/8a85b58033303fd354a680e1494f47801abdca9133c222ae1c2473983f25/placo-0.9.16-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:417a89920b340e3aec19f1f49e1fb06789c679a807450157af8bdf4aef4bc82b", size = 1641806, upload-time = "2025-11-07T14:24:34.736Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/bd/2fb3556c71b0689b3168c0e85fce5befb605affcfe4afb3b5e7b5ba6749f/placo-0.9.16-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a7ef7ac33ba889d2122db0d7ed55eeecdffed020e2282712989bb11e408bab40", size = 1515468, upload-time = "2025-11-07T14:24:36.587Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ea/fd/7dba380720dfb89df582a51d0b2cb43957a36849f676baa3dfc74704e67f/placo-0.9.16-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:885773fe8a8e809022451ec16d47479562a042596f663b8c5bbe762cd616f573", size = 2106540, upload-time = "2025-11-07T14:24:38.149Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7a/40/97c7c799fe4f89111b973d7a5f86626a2ec1d0e6e20ce2988e0a2bda66f5/placo-0.9.16-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:19f097305c714e539fbf19e761897f6daab2ff73f639319431b144e77dd3852e", size = 2178511, upload-time = "2025-11-07T14:24:40.04Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/4d/f1700aae269584477b5d72561d2fc5ace37b4bca167892a74a369849c67e/placo-0.9.16-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:be11fa987702114097ccf3d94e1c4a891796878429e25c8d88b187ecc652e7ae", size = 1641812, upload-time = "2025-11-07T14:24:41.308Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/d7/21d1d0dd1311c0cbd9ccd233cdae520bbe2370095e3c831059d6077c90bd/placo-0.9.16-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c2d65aeb4844eae28006ad3a50c8519b27c701912cc99c46c95e33ed049f3635", size = 1515457, upload-time = "2025-11-07T14:24:42.758Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/e8/939ba23bfa539fb90ab9ab1c2c59ff9a9a46e24699fc90e8ca3ff2948646/placo-0.9.16-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a7633aff1c592c1f45e86a174a372d5d7972673935cb9151391277ff49ec2072", size = 2106538, upload-time = "2025-11-07T14:24:44.517Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/08/00/ad24cc0ad85fbe12267df28c2061e1eaef8f852146c467fcd7a681e11028/placo-0.9.16-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0d97a7284b65fc45aef27865c80cf7e53f04646d35bb18494ab62dfbbc9a35bd", size = 2178514, upload-time = "2025-11-07T14:24:45.994Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ef/03/207b1c087996b918fdbaa5a3a685e3b14b068cd303bf87affdf83f722b33/placo-0.9.15-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:eab7a299e73291fe631c02448b9e9826539f4824e198bcf85f7c91fdd77d054b", size = 1641975, upload-time = "2025-11-03T10:48:48.887Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/55/40432b26bb1c5b9e677fbc41e8d85b54fa8897b7daebb2a22d410b0a7f7b/placo-0.9.15-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:23f9dd19b8d15fa9d86968948b57981ebc6f1decafeffc2d646d8b56f685b50d", size = 1515448, upload-time = "2025-11-03T10:48:50.562Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/8e/e6283201d329409dccf2045b5c1efd73b3dad5268143bbea4668029ca9c6/placo-0.9.15-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:2680a2166c23a0a2aa6226ad75c63a2b2310c812673a5db296616d9af053e076", size = 2106550, upload-time = "2025-11-03T10:48:52.364Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/c3/77efe4c999e1d80ec14879ef73ea2a2144aa12db2b67870a562f87ed5b43/placo-0.9.15-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:1a2202a78bcd2874ca09a9a6526a95b38874803923cb9b3b4b96cd68ab4b7217", size = 2178531, upload-time = "2025-11-03T10:48:53.932Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fe/e7/b5cc5ad53414ff7af3357e0c9d97d902a3ce276e7810f8814fe9f0c1fb70/placo-0.9.15-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:84a445a99b059a512d1b4c64841a91d6f50149c7be9255c65bedeebbe6663989", size = 1641982, upload-time = "2025-11-03T10:48:55.277Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/1c/1c9163d941698a077617f218041efc573d3bf5a1c169a284112bd622fccd/placo-0.9.15-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b3106e7e6b05cbfa494239d8aa14795f7da8ee5dec851602f0d6297e311d7334", size = 1515447, upload-time = "2025-11-03T10:48:56.975Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/22/3d9b9045b89248c8476dd42243bc9821a123d9199e4e96a944124ad80cf1/placo-0.9.15-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:66c3d099e87551401aace04f1293a3c3563b1399319976647846845bf92c3ccf", size = 2106558, upload-time = "2025-11-03T10:48:58.667Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/20/0b/45dbdd2c378a7cece578b7344fda493d5a2aa6777089798a315ce4f97c22/placo-0.9.15-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0e06b7d3d618ddc2b649ab8b0b46db8001fe72fe2fbcc801524df0ccc8a3da40", size = 2178531, upload-time = "2025-11-03T10:49:00.533Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user