mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
Compare commits
33 Commits
chore/add-
...
b75b3ce02d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b75b3ce02d | ||
|
|
5495c10cdf | ||
|
|
1bcba9dec6 | ||
|
|
dd13eda002 | ||
|
|
c01a00a972 | ||
|
|
f75a2ee2f5 | ||
|
|
0edbb68ec3 | ||
|
|
47f8a50fa0 | ||
|
|
51e57789ba | ||
|
|
27669724d2 | ||
|
|
e32a552edb | ||
|
|
596fda60d7 | ||
|
|
3fcea935b2 | ||
|
|
76b63ebb26 | ||
|
|
bbe4ba7a53 | ||
|
|
593253e155 | ||
|
|
acf65faaff | ||
|
|
82a05f9cb4 | ||
|
|
d4abb9d562 | ||
|
|
090d392b19 | ||
|
|
e36d742d7d | ||
|
|
f8a1acb6c9 | ||
|
|
60347bc742 | ||
|
|
c6ec8d00e3 | ||
|
|
0edb693ee4 | ||
|
|
cdae1b9ad8 | ||
|
|
80ecf7bf53 | ||
|
|
5597d539e7 | ||
|
|
dfbedb71d7 | ||
|
|
ebe6c66263 | ||
|
|
0e18bdaf7a | ||
|
|
d5944c410c | ||
|
|
0d37efdb4b |
11
.github/dependabot.yml
vendored
11
.github/dependabot.yml
vendored
@@ -1,11 +0,0 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
cooldown:
|
||||
default-days: 7
|
||||
groups:
|
||||
actions:
|
||||
patterns: ["*"]
|
||||
@@ -79,13 +79,17 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
|
||||
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/act_policy \
|
||||
--robot.type=so101_follower \
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_robot \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--display_data=true \
|
||||
--task="Your task description" \ # can be skipped for ACT
|
||||
--duration=60
|
||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Your task description" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
```
|
||||
|
||||
@@ -105,12 +105,10 @@ These results demonstrate GR00T's strong generalization capabilities across dive
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
||||
|
||||
```bash
|
||||
lerobot-rollout\
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
lerobot-record \
|
||||
--robot.type=bi_so_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
@@ -121,12 +119,14 @@ lerobot-rollout\
|
||||
}' \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--duration=600
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
port="/dev/tty.usbmodem58760431541",
|
||||
id="my_red_robot_arm",
|
||||
)
|
||||
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
)
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem5AB90687491 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem5AB90689011 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=my_awesome_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
</hfoption>
|
||||
@@ -122,48 +122,34 @@ lerobot-teleoperate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
import time
|
||||
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
|
||||
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
|
||||
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
|
||||
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
camera_config = {
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
|
||||
}
|
||||
|
||||
robot_config = KochFollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
id="my_red_robot_arm",
|
||||
cameras=camera_config
|
||||
)
|
||||
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
teleop_config = KochLeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
)
|
||||
|
||||
init_rerun(session_name="teleoperation")
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop_device = SO101Leader(teleop_config)
|
||||
robot = KochFollower(robot_config)
|
||||
teleop_device = KochLeader(teleop_config)
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
TARGET_HZ = 30
|
||||
TIME_PER_FRAME = 1.0 / TARGET_HZ
|
||||
|
||||
while True:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
observation = robot.get_observation()
|
||||
action = teleop_device.get_action()
|
||||
robot.send_action(action)
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
|
||||
elapsed_time = time.perf_counter() - start_time
|
||||
sleep_time = TIME_PER_FRAME - elapsed_time
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -216,11 +202,10 @@ lerobot-record \
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
|
||||
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
@@ -233,56 +218,71 @@ EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
|
||||
def main():
|
||||
# Create robot configuration
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
# Create robot configuration
|
||||
robot_config = SO100FollowerConfig(
|
||||
id="my_awesome_follower_arm",
|
||||
cameras={
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
||||
},
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
)
|
||||
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
teleop_config = SO100LeaderConfig(
|
||||
id="my_awesome_leader_arm",
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop = SO101Leader(teleop_config)
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop = SO100Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
@@ -291,50 +291,26 @@ def main():
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# finalize dataset
|
||||
log_say("Finalizing dataset...")
|
||||
dataset.finalize()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -372,7 +348,7 @@ The `record` function provides a suite of tools for capturing and managing data
|
||||
##### 2. Checkpointing and Resuming
|
||||
|
||||
- Checkpoints are automatically created during recording.
|
||||
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
|
||||
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
|
||||
- To start recording from scratch, **manually delete** the dataset directory.
|
||||
|
||||
##### 3. Recording Parameters
|
||||
@@ -446,7 +422,7 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
robot.connect()
|
||||
@@ -514,83 +490,6 @@ Additionally you can provide extra `tags` or specify a `license` for your model
|
||||
|
||||
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
#### Train using Hugging Face Jobs
|
||||
|
||||
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
|
||||
|
||||
To run the training use this command:
|
||||
|
||||
<hfoptions id="train_with_hf_jobs">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
hf jobs run \
|
||||
--flavor a10g-small \
|
||||
--timeout 4h \
|
||||
--secrets HF_TOKEN \
|
||||
huggingface/lerobot-gpu:latest \
|
||||
-- \
|
||||
python -m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=username/dataset \
|
||||
--policy.type=act \
|
||||
--steps=5000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=username/your_policy \
|
||||
--log_freq=100
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from huggingface_hub import run_job, get_token
|
||||
|
||||
run_name = "act_so101_hf_jobs"
|
||||
dataset_id = "username/dataset"
|
||||
user_hub_id = "username"
|
||||
|
||||
command_args = [
|
||||
"python", "-m", "lerobot.scripts.lerobot_train",
|
||||
"--dataset.repo_id", dataset_id,
|
||||
"--policy.type", "act",
|
||||
"--steps", "5000",
|
||||
"--batch_size", "16",
|
||||
"--num_workers", "4",
|
||||
"--policy.device", "cuda",
|
||||
"--log_freq", "100",
|
||||
"--save_freq", "1000",
|
||||
"--save_checkpoint", "true",
|
||||
"--wandb.enable", "false",
|
||||
"--policy.repo_id", f"{user_hub_id}/{run_name}"
|
||||
]
|
||||
|
||||
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
|
||||
|
||||
job_info = run_job(
|
||||
image="huggingface/lerobot-gpu:latest",
|
||||
command=command_args,
|
||||
flavor="a10g-small",
|
||||
timeout="4h",
|
||||
secrets={"HF_TOKEN": get_token()}
|
||||
)
|
||||
|
||||
print("\n🚀 Job successfully launched!")
|
||||
print(f"🔹 Job ID: {job_info.id}")
|
||||
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
|
||||
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
|
||||
For longer training sessions increase the timeout.
|
||||
|
||||
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
|
||||
|
||||
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
196
docs/source/policy_vla_jepa_README.md
Normal file
196
docs/source/policy_vla_jepa_README.md
Normal file
@@ -0,0 +1,196 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
VLA-JEPA has three main components:
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
### Data flow
|
||||
|
||||
**Training:**
|
||||
|
||||
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
|
||||
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
|
||||
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
|
||||
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
|
||||
|
||||
**Inference:**
|
||||
Only Qwen + the action head are used. The world model is not needed at inference time.
|
||||
|
||||
### Action head details
|
||||
|
||||
Available presets via `action_model_type`:
|
||||
|
||||
| Preset | Hidden dim | Heads | Head dim |
|
||||
| ------- | ---------- | ----- | -------- |
|
||||
| `DiT-B` | 768 | 12 | 64 |
|
||||
| `DiT-L` | 1536 | 32 | 48 |
|
||||
|
||||
### World model details
|
||||
|
||||
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
|
||||
|
||||
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
|
||||
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
|
||||
|
||||
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
|
||||
|
||||
---
|
||||
|
||||
## Pretrained Checkpoints
|
||||
|
||||
Three checkpoints are available, converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
|
||||
|
||||
| Checkpoint | Dataset | Cameras | World model | Action dim |
|
||||
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
|
||||
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 | Disabled\* | 7 |
|
||||
|
||||
\* The SimplerEnv checkpoint was fine-tuned from Pretrain. The world model predictor architecture expects `embed_dim=2048` (2-camera input) but SimplerEnv is single-camera, so the world model cannot be loaded cleanly. Since inference only needs Qwen + the action head, `enable_world_model=False` is set for this variant. See [Fine-tuning on single-camera datasets](#fine-tuning-on-single-camera-datasets) for implications.
|
||||
|
||||
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
Key parameters in `VLAJEPAConfig`:
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ------------------------- | ------- | -------------------------------------------------------------- |
|
||||
| `chunk_size` | 7 | Number of actions predicted per inference call |
|
||||
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
|
||||
| `num_video_frames` | 8 | Video clip length fed to the world model |
|
||||
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
|
||||
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
||||
|
||||
---
|
||||
|
||||
## Training
|
||||
|
||||
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
|
||||
|
||||
### Full training from scratch
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
policy.type=vla_jepa \
|
||||
policy.repo_id=your_org/your_repo \
|
||||
dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning from a pretrained checkpoint
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If you want to go further and freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Reproducing the LIBERO results
|
||||
|
||||
**Training on LIBERO:**
|
||||
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
|
||||
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--steps=30000
|
||||
```
|
||||
|
||||
**Evaluating the pretrained LIBERO-10 checkpoint:**
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5 \
|
||||
|
||||
```
|
||||
|
||||
To evaluate a subset of tasks only:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--env.task_ids='[0,1,2]' \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5 \
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Fine-tuning on single-camera datasets
|
||||
|
||||
The pretrained world model predictor was trained with `embed_dim = num_views × 1024`. If your target dataset has fewer cameras than the source checkpoint, the predictor input projection will have a shape mismatch and cannot be loaded.
|
||||
|
||||
**Option 1 — Disable the world model (recommended)**
|
||||
|
||||
Set `enable_world_model=False`. Only the Qwen backbone and action head are loaded and trained. This matches the original SimplerEnv fine-tuning strategy and is sufficient for good action performance.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.enable_world_model=false \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/single_camera_dataset
|
||||
```
|
||||
|
||||
**Option 2 — Reinitialize the predictor input projection**
|
||||
|
||||
If you want the JEPA self-supervised signal during fine-tuning, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
|
||||
|
||||
**Option 3 - Duplicate frames to match the expected number of cameras**
|
||||
A bit more advanced, you would need to change some parts of the code to support that.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
|
||||
Once you are logged in, you can run inference in your setup by doing:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
lerobot-record \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
||||
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
|
||||
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
# <- RTC optional, use when running on low power hardware \
|
||||
# --inference.type=rtc \
|
||||
# --inference.rtc.execution_horizon=10 \
|
||||
# --inference.rtc.max_guidance_weight=10.0 \
|
||||
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
|
||||
--dataset.episode_time_s=50 \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
# --teleop.id=my_red_leader_arm \
|
||||
# --display_data=true #optional use if you want to see the camera stream \
|
||||
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
|
||||
```
|
||||
|
||||
|
||||
@@ -15,12 +15,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
|
||||
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
|
||||
|
||||
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
||||
of the source video, draws a progress line on each frame, and writes the result.
|
||||
The progress data is read from a parquet file that lives alongside the dataset
|
||||
(configurable via ``--progress-file``).
|
||||
|
||||
Usage:
|
||||
python examples/dataset/create_progress_videos.py \
|
||||
@@ -58,26 +56,22 @@ SCORE_FONT_SCALE = 0.8
|
||||
TASK_FONT_SCALE = 0.55
|
||||
|
||||
|
||||
def download_episode_metadata(
|
||||
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> Path:
|
||||
"""Download only the metadata and per-frame progress file for a dataset.
|
||||
def download_episode_metadata(repo_id: str, episode: int) -> Path:
|
||||
"""Download only the metadata and sarm_progress files for a dataset.
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace dataset repository ID.
|
||||
episode: Episode index (used for logging only; all meta is fetched).
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Local cache path for the downloaded snapshot.
|
||||
"""
|
||||
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
|
||||
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
|
||||
local_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", progress_file],
|
||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
@@ -221,28 +215,25 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
|
||||
return video_path
|
||||
|
||||
|
||||
def load_progress_data(
|
||||
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> np.ndarray | None:
|
||||
"""Load per-frame progress values for an episode.
|
||||
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
|
||||
"""Load sarm_progress values for an episode.
|
||||
|
||||
Args:
|
||||
local_path: Dataset cache root.
|
||||
episode: Episode index.
|
||||
progress_file: Filename of the per-frame progress parquet.
|
||||
|
||||
Returns:
|
||||
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
||||
"""
|
||||
parquet_path = local_path / progress_file
|
||||
parquet_path = local_path / "sarm_progress.parquet"
|
||||
if not parquet_path.exists():
|
||||
logging.warning("%s not found", progress_file)
|
||||
logging.warning("sarm_progress.parquet not found")
|
||||
return None
|
||||
df = pd.read_parquet(parquet_path)
|
||||
logging.info(" %s columns: %s", progress_file, list(df.columns))
|
||||
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
|
||||
episode_df = df[df["episode_index"] == episode].copy()
|
||||
if episode_df.empty:
|
||||
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
|
||||
logging.warning("No sarm_progress rows for episode %d", episode)
|
||||
return None
|
||||
episode_df = episode_df.sort_values("frame_index")
|
||||
|
||||
@@ -585,7 +576,6 @@ def process_dataset(
|
||||
camera_key: str | None,
|
||||
output_dir: Path,
|
||||
create_gif: bool = False,
|
||||
progress_file: str = "sarm_progress.parquet",
|
||||
) -> Path | None:
|
||||
"""Full pipeline: download, extract metadata, composite progress, write output.
|
||||
|
||||
@@ -595,8 +585,6 @@ def process_dataset(
|
||||
camera_key: Camera key to use, or None for auto-selection.
|
||||
output_dir: Directory to write output files.
|
||||
create_gif: If True, also generate a GIF from the MP4.
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Path to the final output file, or None on failure.
|
||||
@@ -604,7 +592,7 @@ def process_dataset(
|
||||
safe_name = repo_id.replace("/", "_")
|
||||
logging.info("Processing: %s | episode %d", repo_id, episode)
|
||||
|
||||
local_path = download_episode_metadata(repo_id, episode, progress_file)
|
||||
local_path = download_episode_metadata(repo_id, episode)
|
||||
logging.info(" Local cache: %s", local_path)
|
||||
|
||||
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
||||
@@ -612,9 +600,9 @@ def process_dataset(
|
||||
|
||||
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
||||
|
||||
progress_data = load_progress_data(local_path, episode, progress_file)
|
||||
progress_data = load_progress_data(local_path, episode)
|
||||
if progress_data is None:
|
||||
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
|
||||
logging.error("Could not load sarm_progress data. Skipping overlay.")
|
||||
return None
|
||||
|
||||
logging.info(" Progress frames: %d", len(progress_data))
|
||||
@@ -639,7 +627,7 @@ def process_dataset(
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
|
||||
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
@@ -670,15 +658,6 @@ def main() -> None:
|
||||
action="store_true",
|
||||
help="Also generate a GIF from the MP4 output.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-file",
|
||||
type=str,
|
||||
default="sarm_progress.parquet",
|
||||
help=(
|
||||
"Filename of the per-frame progress parquet inside the dataset repo "
|
||||
"(default: 'sarm_progress.parquet')."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
@@ -691,7 +670,6 @@ def main() -> None:
|
||||
camera_key=args.camera_key,
|
||||
output_dir=args.output_dir,
|
||||
create_gif=args.gif,
|
||||
progress_file=args.progress_file,
|
||||
)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -138,9 +138,7 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
||||
# Common
|
||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
@@ -214,6 +212,7 @@ sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<3
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
@@ -278,6 +277,7 @@ all = [
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[vla_jepa]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
|
||||
@@ -18,25 +18,12 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
from lerobot.utils.import_utils import _placo_available, require_package
|
||||
|
||||
_placo_runtime_error: ImportError | None = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if TYPE_CHECKING or _placo_available:
|
||||
import placo # type: ignore[import-not-found]
|
||||
else:
|
||||
try:
|
||||
import placo # type: ignore[import-not-found]
|
||||
except ImportError as _placo_import_err:
|
||||
placo = None
|
||||
_placo_runtime_error = _placo_import_err
|
||||
|
||||
|
||||
def _raise_if_placo_unusable() -> None:
|
||||
if placo is None and _placo_runtime_error is not None:
|
||||
raise ImportError(
|
||||
f"placo is installed but failed to import: {_placo_runtime_error!s}"
|
||||
) from _placo_runtime_error
|
||||
placo = None
|
||||
|
||||
|
||||
class RobotKinematics:
|
||||
@@ -57,7 +44,6 @@ class RobotKinematics:
|
||||
joint_names (list[str] | None): List of joint names to use for the kinematics solver
|
||||
"""
|
||||
require_package("placo", extra="placo-dep")
|
||||
_raise_if_placo_unusable()
|
||||
|
||||
self.robot = placo.RobotWrapper(urdf_path)
|
||||
self.solver = placo.KinematicsSolver(self.robot)
|
||||
|
||||
@@ -56,6 +56,7 @@ from .pretrained import PreTrainedPolicy
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig
|
||||
@@ -151,6 +152,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .eo1.modeling_eo1 import EO1Policy
|
||||
|
||||
return EO1Policy
|
||||
elif name == "vla_jepa":
|
||||
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
|
||||
return VLAJEPAPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -203,6 +208,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return WallXConfig(**kwargs)
|
||||
elif policy_type == "eo1":
|
||||
return EO1Config(**kwargs)
|
||||
elif policy_type == "vla_jepa":
|
||||
return VLAJEPAConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -406,6 +413,7 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
@@ -414,6 +422,14 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, VLAJEPAConfig):
|
||||
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
processors = make_vla_jepa_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -26,14 +26,9 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from huggingface_hub.dataclasses import strict
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
|
||||
def strict(cls):
|
||||
return cls
|
||||
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
@@ -178,20 +173,19 @@ N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@strict
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
backbone_cfg: dict[str, Any] | None = None
|
||||
action_head_cfg: dict[str, Any] | None = None
|
||||
action_horizon: int = 0
|
||||
action_dim: int = 0
|
||||
backbone_cfg: dict
|
||||
action_head_cfg: dict
|
||||
action_horizon: int
|
||||
action_dim: int
|
||||
compute_dtype: str = "float32"
|
||||
|
||||
def __post_init__(self, **kwargs):
|
||||
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
|
||||
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
|
||||
super().__post_init__(**kwargs)
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
# real model
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -29,7 +30,6 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,7 +41,6 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -142,15 +141,6 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -237,13 +227,16 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = layers[i]
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -265,16 +258,15 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma_layer.self_attn,
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -282,13 +274,13 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = layers[i]
|
||||
layer = models[i].layers[layer_idx]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -496,9 +488,8 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -508,39 +499,36 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
for layer_idx in range(num_layers):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -919,7 +907,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -29,7 +30,6 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,7 +41,6 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -139,15 +138,6 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -234,13 +224,16 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = layers[i]
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -262,16 +255,15 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma_layer.self_attn,
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -279,13 +271,13 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = layers[i]
|
||||
layer = models[i].layers[layer_idx]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -493,9 +485,8 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -505,39 +496,36 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
for layer_idx in range(num_layers):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -892,7 +880,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
1
src/lerobot/policies/vla_jepa/README.md
Symbolic link
1
src/lerobot/policies/vla_jepa/README.md
Symbolic link
@@ -0,0 +1 @@
|
||||
/home/maxime/github/robots/lerobot/docs/source/policy_vla_jepa_README.md
|
||||
10
src/lerobot/policies/vla_jepa/__init__.py
Normal file
10
src/lerobot/policies/vla_jepa/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .modeling_vla_jepa import VLAJEPAPolicy
|
||||
from .processor_vla_jepa import VLAJEPANewLineProcessor, make_vla_jepa_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"VLAJEPAConfig",
|
||||
"VLAJEPAPolicy",
|
||||
"VLAJEPANewLineProcessor",
|
||||
"make_vla_jepa_pre_post_processors",
|
||||
]
|
||||
327
src/lerobot/policies/vla_jepa/action_head.py
Normal file
327
src/lerobot/policies/vla_jepa/action_head.py
Normal file
@@ -0,0 +1,327 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||
else:
|
||||
|
||||
class ModelMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
class ConfigMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
register_to_config = lambda f: f # noqa: E731
|
||||
Attention = FeedForward = TimestepEmbedding = Timesteps = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
def swish(x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
timesteps = timesteps.float()
|
||||
batch_size, seq_len = timesteps.shape
|
||||
half_dim = self.embedding_dim // 2
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
|
||||
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
|
||||
|
||||
|
||||
class ActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim: int, hidden_size: int):
|
||||
super().__init__()
|
||||
self.layer1 = nn.Linear(action_dim, hidden_size)
|
||||
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
|
||||
self.layer3 = nn.Linear(hidden_size, hidden_size)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = actions.shape
|
||||
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
|
||||
raise ValueError("timesteps must have shape [batch_size].")
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
|
||||
action_emb = self.layer1(actions)
|
||||
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||
return self.layer3(swish(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
require_package("diffusers", extra="vla_jepa")
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
|
||||
return self.timestep_embedder(projected)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
||||
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
|
||||
self.silu = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
||||
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
|
||||
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
is_cross_attention: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_cross_attention = is_cross_attention
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
out_bias=True,
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None,
|
||||
temb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.norm1(hidden_states, temb)
|
||||
attention_context = encoder_hidden_states if self.is_cross_attention else None
|
||||
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
|
||||
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DiT(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.timestep_encoder = TimestepEncoder(self.inner_dim)
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
|
||||
is_cross_attention=layer_idx % 2 == 0,
|
||||
)
|
||||
for layer_idx in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
temb = self.timestep_encoder(timestep)
|
||||
x = hidden_states
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
|
||||
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
|
||||
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return self.proj_out_2(x)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionModelPreset:
|
||||
hidden_size: int
|
||||
attention_head_dim: int
|
||||
num_attention_heads: int
|
||||
|
||||
|
||||
DIT_PRESETS = {
|
||||
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
|
||||
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
|
||||
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
|
||||
}
|
||||
|
||||
|
||||
class VLAJEPAActionHead(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
|
||||
super().__init__()
|
||||
preset = DIT_PRESETS[config.action_model_type]
|
||||
self.config = config
|
||||
num_heads = config.action_num_heads or preset.num_attention_heads
|
||||
head_dim = config.action_attention_head_dim or preset.attention_head_dim
|
||||
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
|
||||
|
||||
self.input_embedding_dim = inner_dim
|
||||
self.action_horizon = config.chunk_size
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
hidden_size = config.action_hidden_size
|
||||
self.model = DiT(
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=head_dim,
|
||||
output_dim=hidden_size,
|
||||
num_layers=config.action_num_layers,
|
||||
dropout=config.action_dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
|
||||
self.action_decoder = nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(hidden_size, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, config.action_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
self.state_encoder = (
|
||||
nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(config.state_dim, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, inner_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
if config.state_dim > 0
|
||||
else None
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
|
||||
self.position_embedding = nn.Embedding(
|
||||
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
|
||||
inner_dim,
|
||||
)
|
||||
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
|
||||
return (self.config.action_noise_s - sample) / self.config.action_noise_s
|
||||
|
||||
def _build_inputs(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None,
|
||||
timesteps: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
action_features = self.action_encoder(actions, timesteps)
|
||||
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
|
||||
action_features = action_features + self.position_embedding(pos_ids)[None]
|
||||
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
|
||||
seq = [future_tokens, action_features]
|
||||
if state is not None and self.state_encoder is not None:
|
||||
if state.ndim == 2:
|
||||
state = state.unsqueeze(1)
|
||||
seq.insert(0, self.state_encoder(state))
|
||||
return torch.cat(seq, dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
action_is_pad: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
noise = torch.randn_like(actions)
|
||||
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
|
||||
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
|
||||
velocity = actions - noise
|
||||
t_discretized = (t * self.config.action_num_timestep_buckets).long()
|
||||
|
||||
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=t_discretized,
|
||||
)
|
||||
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
|
||||
|
||||
if action_is_pad is None:
|
||||
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
|
||||
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
|
||||
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
|
||||
num_valid = valid_mask.sum() * loss.shape[-1]
|
||||
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = conditioning_tokens.shape[0]
|
||||
actions = torch.randn(
|
||||
batch_size,
|
||||
self.action_horizon,
|
||||
self.config.action_dim,
|
||||
dtype=conditioning_tokens.dtype,
|
||||
device=conditioning_tokens.device,
|
||||
)
|
||||
dt = 1.0 / max(self.num_inference_timesteps, 1)
|
||||
for step in range(self.num_inference_timesteps):
|
||||
t_cont = step / float(max(self.num_inference_timesteps, 1))
|
||||
t_value = int(t_cont * self.config.action_num_timestep_buckets)
|
||||
timesteps = torch.full(
|
||||
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
|
||||
)
|
||||
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=timesteps,
|
||||
)
|
||||
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
|
||||
actions = actions + dt * pred_velocity
|
||||
return actions
|
||||
133
src/lerobot/policies/vla_jepa/configuration_vla_jepa.py
Normal file
133
src/lerobot/policies/vla_jepa/configuration_vla_jepa.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("vla_jepa")
|
||||
@dataclass
|
||||
class VLAJEPAConfig(PreTrainedConfig):
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 7
|
||||
n_action_steps: int = 7
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
|
||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||
freeze_qwen: bool = False
|
||||
enable_world_model: bool = True
|
||||
reinit_action_head: bool = False
|
||||
|
||||
tokenizer_padding_side: str = "left"
|
||||
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
||||
special_action_token: str = "<|action_{}|>"
|
||||
embodied_action_token: str = "<|embodied_action|>"
|
||||
|
||||
action_dim: int = 7
|
||||
state_dim: int = 8
|
||||
|
||||
num_action_tokens_per_timestep: int = 8
|
||||
num_embodied_action_tokens_per_instruction: int = 32
|
||||
num_inference_timesteps: int = 4
|
||||
|
||||
action_hidden_size: int = 1024
|
||||
action_model_type: str = "DiT-B"
|
||||
action_num_layers: int = 16
|
||||
action_num_heads: int | None = None
|
||||
action_attention_head_dim: int | None = None
|
||||
action_dropout: float = 0.2
|
||||
action_num_timestep_buckets: int = 1000
|
||||
action_noise_beta_alpha: float = 1.5
|
||||
action_noise_beta_beta: float = 1.0
|
||||
action_noise_s: float = 0.999
|
||||
num_target_vision_tokens: int = 32
|
||||
action_max_seq_len: int = 1024
|
||||
|
||||
# total video frames loaded per sample
|
||||
num_video_frames: int = 8
|
||||
predictor_depth: int = 12
|
||||
predictor_num_heads: int = 8
|
||||
predictor_mlp_ratio: float = 4.0
|
||||
predictor_dropout: float = 0.0
|
||||
world_model_loss_weight: float = 0.1
|
||||
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
|
||||
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
|
||||
|
||||
resize_images_to: tuple[int, int] | None = None
|
||||
binarize_gripper_action: bool = True
|
||||
pre_snap_gripper_action: bool = True
|
||||
clip_normalized_actions: bool = True
|
||||
torch_dtype: str = "bfloat16"
|
||||
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_grad_clip_norm: float = 10.0
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.freeze_qwen and self.enable_world_model:
|
||||
# freezing qwen backbone makes world model training irrelevant since no grad flows
|
||||
self.enable_world_model = False
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
|
||||
if self.num_video_frames < 2 * self.jepa_tubelet_size:
|
||||
raise ValueError(
|
||||
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
|
||||
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features:
|
||||
raise ValueError("VLAJEPA requires at least one visual input feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("VLAJEPA requires an action output feature.")
|
||||
self.action_dim = self.action_feature.shape[0]
|
||||
if self.robot_state_feature is not None:
|
||||
self.state_dim = self.robot_state_feature.shape[0]
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
|
||||
# matches original repo's observation_indices=list(range(video_horizon))
|
||||
return list(range(self.num_video_frames))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
454
src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py
Normal file
454
src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py
Normal file
@@ -0,0 +1,454 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Convert all VLA-JEPA .pt checkpoints (ginwind/VLA-JEPA) to LeRobot safetensors
|
||||
format and upload them to maximellerbach org inside a HF collection.
|
||||
|
||||
Usage:
|
||||
uv run python convert_vla_jepa_checkpoints.py
|
||||
|
||||
For each variant the script:
|
||||
1. Downloads the .pt checkpoint.
|
||||
2. Extracts the state dict.
|
||||
3. Instantiates VLAJEPAPolicy with the variant's confirmed config.
|
||||
4. Loads the state dict (strict=False — mismatches printed to stdout).
|
||||
5. push_to_hub → writes model.safetensors + config.json in LeRobot format.
|
||||
6. Adds the new repo to a shared HF collection.
|
||||
|
||||
Config sources
|
||||
--------------
|
||||
Numeric hyper-params : ginwind/VLA-JEPA/<variant>/config.json
|
||||
Image keys LIBERO : lerobot/libero_10 meta/info.json ✓ confirmed
|
||||
Image keys Pretrain : lerobot/droid_1.0.1 meta/info.json ✓ confirmed
|
||||
Image keys SimplerEnv: OXE Bridge/RT1 are single-camera ✓ confirmed
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from safetensors.torch import save_file as save_safetensors
|
||||
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level settings
|
||||
# ---------------------------------------------------------------------------
|
||||
SOURCE_REPO_ID = "ginwind/VLA-JEPA"
|
||||
TARGET_ORG = "maximellerbach"
|
||||
COLLECTION_TITLE = "VLA-JEPA"
|
||||
COLLECTION_DESCRIPTION = (
|
||||
"VLA-JEPA model checkpoints (LIBERO, Pretrain, SimplerEnv) converted from .pt to safetensors via LeRobot."
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Key mapping — mirrors todo_converter.py map_key() so both converters
|
||||
# produce identical safetensors layouts that match the LeRobot action_head code.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _normalize_source_key(key: str) -> str:
|
||||
return key[len("module.") :] if key.startswith("module.") else key
|
||||
|
||||
|
||||
def _map_checkpoint_key(raw_key: str) -> str | None:
|
||||
"""Map original VLA-JEPA state-dict keys to LeRobot vla_jepa layout."""
|
||||
key = _normalize_source_key(raw_key)
|
||||
|
||||
if key.startswith("qwen_vl_interface."):
|
||||
return "model.qwen." + key[len("qwen_vl_interface.") :]
|
||||
if key.startswith("vj_encoder."):
|
||||
return "model.video_encoder." + key[len("vj_encoder.") :]
|
||||
if key.startswith("vj_predictor."):
|
||||
return "model.video_predictor." + key[len("vj_predictor.") :]
|
||||
if key.startswith("action_model."):
|
||||
# LeRobot code uses the same sub-key names as the source checkpoint,
|
||||
# so only the top-level "model." prefix needs to be added.
|
||||
return "model." + key
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_dataset_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict | None:
|
||||
"""Download dataset_statistics.json and return {action: {...}, state: {...}} stats dict."""
|
||||
import json
|
||||
|
||||
stats_file = f"{subfolder}/dataset_statistics.json"
|
||||
try:
|
||||
local = api.hf_hub_download(source_repo_id, stats_file)
|
||||
data = json.loads(Path(local).read_text())
|
||||
# Original repo nests stats under a robot key, e.g. {"franka": {"action": {...}, "state": {...}}}
|
||||
for robot_key in data:
|
||||
robot_data = data[robot_key]
|
||||
if isinstance(robot_data, dict) and "action" in robot_data:
|
||||
log.info(" Loaded dataset stats from %s (robot key: %s)", stats_file, robot_key)
|
||||
result = {"action": robot_data["action"]}
|
||||
if "state" in robot_data:
|
||||
result["observation.state"] = robot_data["state"]
|
||||
log.info(" Also loaded state stats.")
|
||||
return result
|
||||
log.warning(" %s found but no 'action' key under any robot key — skipping stats.", stats_file)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(" Could not fetch %s: %s — postprocessor will have no unnorm stats.", stats_file, exc)
|
||||
return None
|
||||
|
||||
|
||||
def _set_if_present(d: dict, key: str, value) -> None:
|
||||
if value is not None:
|
||||
d[key] = value
|
||||
|
||||
|
||||
def _deep_get(mapping: dict, path: tuple, default=None):
|
||||
current = mapping
|
||||
for key in path:
|
||||
if not isinstance(current, dict) or key not in current:
|
||||
return default
|
||||
current = current[key]
|
||||
return current
|
||||
|
||||
|
||||
def _fetch_source_config(api: HfApi, source_repo_id: str, subfolder: str) -> dict:
|
||||
"""Download config.yaml from the source HF repo for a given variant subfolder."""
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
log.warning("PyYAML not installed — cannot apply source config.yaml overrides.")
|
||||
return {}
|
||||
config_file = f"{subfolder}/config.yaml"
|
||||
try:
|
||||
local = api.hf_hub_download(source_repo_id, config_file)
|
||||
data = yaml.safe_load(Path(local).read_text()) or {}
|
||||
if isinstance(data, dict):
|
||||
log.info(" Loaded source config from %s", config_file)
|
||||
return data
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(" Could not fetch %s: %s — using hardcoded defaults.", config_file, exc)
|
||||
return {}
|
||||
|
||||
|
||||
def _apply_source_config(kwargs: dict, source_config: dict) -> None:
|
||||
"""Apply ginwind/VLA-JEPA config.yaml values to kwargs, mirroring todo_converter.py logic."""
|
||||
if not source_config:
|
||||
return
|
||||
|
||||
data_cfg = _deep_get(source_config, ("datasets", "vla_data"), {})
|
||||
action_cfg = _deep_get(source_config, ("framework", "action_model"), {})
|
||||
diffusion_cfg = _deep_get(source_config, ("framework", "action_model", "diffusion_model_cfg"), {})
|
||||
video_cfg = _deep_get(source_config, ("framework", "vj2_model"), {})
|
||||
trainer_cfg = source_config.get("trainer", {})
|
||||
|
||||
prompt_template = data_cfg.get("CoT_prompt")
|
||||
if prompt_template:
|
||||
kwargs["prompt_template"] = str(prompt_template)
|
||||
|
||||
action_horizon = action_cfg.get("action_horizon")
|
||||
if action_horizon is not None:
|
||||
kwargs["chunk_size"] = int(action_horizon)
|
||||
kwargs["n_action_steps"] = int(action_horizon)
|
||||
|
||||
_set_if_present(
|
||||
kwargs,
|
||||
"num_action_tokens_per_timestep",
|
||||
video_cfg.get("num_action_tokens_per_timestep", action_cfg.get("num_action_tokens_per_timestep")),
|
||||
)
|
||||
_set_if_present(
|
||||
kwargs,
|
||||
"num_embodied_action_tokens_per_instruction",
|
||||
video_cfg.get(
|
||||
"num_embodied_action_tokens_per_instruction",
|
||||
action_cfg.get("num_embodied_action_tokens_per_instruction"),
|
||||
),
|
||||
)
|
||||
_set_if_present(kwargs, "num_inference_timesteps", action_cfg.get("num_inference_timesteps"))
|
||||
_set_if_present(kwargs, "special_action_token", video_cfg.get("special_action_token"))
|
||||
_set_if_present(kwargs, "embodied_action_token", video_cfg.get("embodied_action_token"))
|
||||
_set_if_present(
|
||||
kwargs, "action_hidden_size", action_cfg.get("action_hidden_dim", action_cfg.get("hidden_size"))
|
||||
)
|
||||
_set_if_present(kwargs, "action_model_type", action_cfg.get("action_model_type"))
|
||||
_set_if_present(kwargs, "action_noise_beta_alpha", action_cfg.get("noise_beta_alpha"))
|
||||
_set_if_present(kwargs, "action_noise_beta_beta", action_cfg.get("noise_beta_beta"))
|
||||
_set_if_present(kwargs, "action_noise_s", action_cfg.get("noise_s"))
|
||||
_set_if_present(kwargs, "action_num_timestep_buckets", action_cfg.get("num_timestep_buckets"))
|
||||
_set_if_present(kwargs, "repeated_diffusion_steps", action_cfg.get("repeated_diffusion_steps"))
|
||||
_set_if_present(kwargs, "action_num_layers", diffusion_cfg.get("num_layers"))
|
||||
_set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout"))
|
||||
|
||||
_set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames"))
|
||||
_set_if_present(kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth")))
|
||||
_set_if_present(
|
||||
kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads"))
|
||||
)
|
||||
_set_if_present(kwargs, "predictor_mlp_ratio", video_cfg.get("predictor_mlp_ratio"))
|
||||
|
||||
_set_if_present(kwargs, "optimizer_grad_clip_norm", trainer_cfg.get("max_grad_norm"))
|
||||
learning_rate = trainer_cfg.get("learning_rate", {})
|
||||
if isinstance(learning_rate, dict):
|
||||
_set_if_present(kwargs, "optimizer_lr", learning_rate.get("action_model"))
|
||||
optimizer_cfg = trainer_cfg.get("optimizer", {})
|
||||
if isinstance(optimizer_cfg, dict):
|
||||
_set_if_present(kwargs, "optimizer_eps", optimizer_cfg.get("eps"))
|
||||
_set_if_present(kwargs, "optimizer_weight_decay", optimizer_cfg.get("weight_decay"))
|
||||
betas = optimizer_cfg.get("betas")
|
||||
if betas is not None:
|
||||
kwargs["optimizer_betas"] = tuple(betas)
|
||||
scheduler = trainer_cfg.get("scheduler", {})
|
||||
if isinstance(scheduler, dict):
|
||||
_set_if_present(kwargs, "scheduler_warmup_steps", scheduler.get("warmup_steps"))
|
||||
_set_if_present(kwargs, "scheduler_decay_lr", scheduler.get("min_lr"))
|
||||
_set_if_present(kwargs, "scheduler_warmup_steps", trainer_cfg.get("num_warmup_steps"))
|
||||
scheduler_kwargs = trainer_cfg.get("scheduler_specific_kwargs", {})
|
||||
if isinstance(scheduler_kwargs, dict):
|
||||
_set_if_present(kwargs, "scheduler_decay_lr", scheduler_kwargs.get("min_lr"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Architecture — identical across all 4 variants (from config.json)
|
||||
# ---------------------------------------------------------------------------
|
||||
_ARCH = {
|
||||
"qwen_model_name": "Qwen/Qwen3-VL-2B-Instruct", # 2B, NOT the default 4B
|
||||
"chunk_size": 7,
|
||||
"n_action_steps": 7,
|
||||
"num_video_frames": 8,
|
||||
"jepa_tubelet_size": 2,
|
||||
"num_action_tokens_per_timestep": 8,
|
||||
"num_embodied_action_tokens_per_instruction": 32,
|
||||
"num_inference_timesteps": 4,
|
||||
"action_hidden_size": 1024,
|
||||
"action_model_type": "DiT-B",
|
||||
# Explicit dims matching DiT-B preset and ginwind checkpoint shape
|
||||
"action_num_heads": 12,
|
||||
"action_attention_head_dim": 64,
|
||||
"action_num_layers": 16,
|
||||
"action_dropout": 0.2,
|
||||
"repeated_diffusion_steps": 8,
|
||||
"action_noise_beta_alpha": 1.5,
|
||||
"action_noise_beta_beta": 1.0,
|
||||
"action_noise_s": 0.999,
|
||||
"action_num_timestep_buckets": 1000,
|
||||
# World model predictor (12 blocks, confirmed from checkpoint)
|
||||
"predictor_depth": 12,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image-key sets (confirmed sources in module docstring)
|
||||
# ---------------------------------------------------------------------------
|
||||
# LIBERO — confirmed from lerobot/libero_10 meta/info.json
|
||||
_LIBERO_CAMS = [
|
||||
"observation.images.image", # agentview camera
|
||||
"observation.images.image2", # eye-in-hand camera
|
||||
]
|
||||
|
||||
# DROID pretrain — 2 views match the predictor embed_dim=2 × 1024=2048 in checkpoint
|
||||
_DROID_CAMS = [
|
||||
"observation.images.exterior_1_left",
|
||||
"observation.images.exterior_2_left",
|
||||
]
|
||||
|
||||
# OXE Bridge + RT1 — single-camera; world model disabled (predictor embed_dim mismatch)
|
||||
_OXE_CAMS = [
|
||||
"observation.images.image",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_config(
|
||||
camera_keys: list[str],
|
||||
with_state: bool,
|
||||
enable_world_model: bool = True,
|
||||
source_config: dict | None = None,
|
||||
):
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
kwargs = dict(_ARCH)
|
||||
_apply_source_config(kwargs, source_config or {})
|
||||
|
||||
# Image resolution: prefer source config, fall back to 224
|
||||
data_cfg = _deep_get(source_config or {}, ("datasets", "vla_data"), {})
|
||||
raw_res = data_cfg.get("resolution_size")
|
||||
resolution_size = int(raw_res) if raw_res is not None else 224
|
||||
image_shape = (3, resolution_size, resolution_size)
|
||||
# Always set resize_images_to so the policy resizes env images to the training resolution,
|
||||
# regardless of what resolution the eval env renders at.
|
||||
kwargs["resize_images_to"] = (resolution_size, resolution_size)
|
||||
|
||||
# State / action dims: prefer source config
|
||||
action_cfg = _deep_get(source_config or {}, ("framework", "action_model"), {})
|
||||
state_dim = int(action_cfg["state_dim"]) if "state_dim" in action_cfg else 8
|
||||
action_dim = int(action_cfg["action_dim"]) if "action_dim" in action_cfg else 7
|
||||
|
||||
input_features = {k: PolicyFeature(type=FeatureType.VISUAL, shape=image_shape) for k in camera_keys}
|
||||
if with_state:
|
||||
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))
|
||||
|
||||
cfg = VLAJEPAConfig(
|
||||
input_features=input_features,
|
||||
output_features={
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
|
||||
},
|
||||
enable_world_model=enable_world_model,
|
||||
binarize_gripper_action=True,
|
||||
clip_normalized_actions=True,
|
||||
**kwargs,
|
||||
)
|
||||
cfg.validate_features()
|
||||
return cfg
|
||||
|
||||
|
||||
# Maps each subfolder in SOURCE_REPO_ID to (camera_keys, with_state, enable_world_model, repo_suffix)
|
||||
VARIANTS: dict[str, tuple] = {
|
||||
"LIBERO": (_LIBERO_CAMS, True, True, "LIBERO"),
|
||||
"Pretrain": (_DROID_CAMS, False, True, "Pretrain"),
|
||||
# SimplerEnv uses a single camera; the predictor embed_dim (2048) would mismatch, so
|
||||
# disable the world model — only qwen + action_model weights are needed for inference.
|
||||
"SimplerEnv": (_OXE_CAMS, False, False, "SimplerEnv"),
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extract_state_dict(ckpt: object) -> dict[str, torch.Tensor]:
|
||||
if isinstance(ckpt, dict):
|
||||
sd = ckpt.get("state_dict") or ckpt.get("model_state_dict") or ckpt.get("model")
|
||||
if sd is None:
|
||||
sd = ckpt
|
||||
else:
|
||||
sd = ckpt
|
||||
return {k: v for k, v in sd.items() if isinstance(v, torch.Tensor)}
|
||||
|
||||
|
||||
def subfolder_of(pt_path: str) -> str | None:
|
||||
for part in Path(pt_path).parts:
|
||||
if part in VARIANTS:
|
||||
return part
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> None:
|
||||
api = HfApi()
|
||||
|
||||
log.info("Listing .pt files in %s …", SOURCE_REPO_ID)
|
||||
pt_files = [f for f in api.list_repo_files(SOURCE_REPO_ID) if f.endswith(".pt")]
|
||||
if not pt_files:
|
||||
log.error("No .pt files found.")
|
||||
return
|
||||
for f in pt_files:
|
||||
log.info(" %s", f)
|
||||
|
||||
# Create / reuse the collection once
|
||||
collection = api.create_collection(
|
||||
title=COLLECTION_TITLE,
|
||||
description=COLLECTION_DESCRIPTION,
|
||||
namespace=TARGET_ORG,
|
||||
exists_ok=True,
|
||||
)
|
||||
log.info("Collection: %s", collection.url)
|
||||
|
||||
for pt_filename in pt_files:
|
||||
log.info("\n=== %s ===", pt_filename)
|
||||
|
||||
subfolder = subfolder_of(pt_filename)
|
||||
if subfolder is None:
|
||||
log.warning(" No variant entry for '%s' — skipping.", pt_filename)
|
||||
continue
|
||||
|
||||
camera_keys, with_state, enable_world_model, repo_suffix = VARIANTS[subfolder]
|
||||
target_repo_id = f"{TARGET_ORG}/VLA-JEPA-{repo_suffix}"
|
||||
|
||||
log.info(
|
||||
" cameras=%d with_state=%s wm=%s → %s",
|
||||
len(camera_keys),
|
||||
with_state,
|
||||
enable_world_model,
|
||||
target_repo_id,
|
||||
)
|
||||
|
||||
# 1. Download
|
||||
local_pt = api.hf_hub_download(SOURCE_REPO_ID, pt_filename)
|
||||
log.info(" Downloaded → %s", local_pt)
|
||||
|
||||
# 2. Load checkpoint
|
||||
try:
|
||||
ckpt = torch.load(local_pt, map_location="cpu", mmap=True, weights_only=False) # nosec B614
|
||||
except TypeError:
|
||||
ckpt = torch.load(local_pt, map_location="cpu") # nosec B614
|
||||
|
||||
sd = extract_state_dict(ckpt)
|
||||
|
||||
# Map source key names → LeRobot layout (handles layer1→w1, transformer_blocks→blocks, etc.)
|
||||
mapped_sd: dict[str, torch.Tensor] = {}
|
||||
skipped_keys: list[str] = []
|
||||
for raw_key, value in sd.items():
|
||||
target_key = _map_checkpoint_key(raw_key)
|
||||
if target_key is None:
|
||||
skipped_keys.append(raw_key)
|
||||
else:
|
||||
mapped_sd[target_key] = value
|
||||
log.info(" %d tensors mapped, %d skipped", len(mapped_sd), len(skipped_keys))
|
||||
if skipped_keys:
|
||||
log.info(" Skipped sample: %s", skipped_keys[:5])
|
||||
log.info(" First 5 mapped keys: %s", list(mapped_sd)[:5])
|
||||
|
||||
# 3. Fetch action + state stats needed by the pre/postprocessor unnormalizers
|
||||
dataset_stats = _fetch_dataset_stats(api, SOURCE_REPO_ID, subfolder)
|
||||
|
||||
# 4. Build config (no policy instantiation — avoids loading backbone from Hub)
|
||||
source_config = _fetch_source_config(api, SOURCE_REPO_ID, subfolder)
|
||||
config = _build_config(camera_keys, with_state, enable_world_model, source_config)
|
||||
|
||||
# 5. Save everything to a temp dir and upload in one shot
|
||||
api.create_repo(target_repo_id, repo_type="model", exist_ok=True)
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
save_dir = Path(tmp)
|
||||
|
||||
log.info(" Saving model.safetensors …")
|
||||
save_safetensors(mapped_sd, save_dir / "model.safetensors")
|
||||
|
||||
config._save_pretrained(save_dir) # writes config.json via draccus
|
||||
|
||||
preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config, dataset_stats)
|
||||
preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json
|
||||
postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json
|
||||
|
||||
log.info(" Uploading …")
|
||||
commit_url = api.upload_folder(
|
||||
folder_path=save_dir,
|
||||
repo_id=target_repo_id,
|
||||
repo_type="model",
|
||||
commit_message=f"Convert {Path(pt_filename).name} to safetensors",
|
||||
)
|
||||
log.info(" Uploaded → %s", commit_url)
|
||||
|
||||
# 6. Add to collection
|
||||
api.add_collection_item(
|
||||
collection_slug=collection.slug,
|
||||
item_id=target_repo_id,
|
||||
item_type="model",
|
||||
exists_ok=True,
|
||||
)
|
||||
log.info(" Added to collection.")
|
||||
|
||||
log.info("\nAll done. Collection: %s", collection.url)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
607
src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
Normal file
607
src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
Normal file
@@ -0,0 +1,607 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModel, AutoVideoProcessor
|
||||
else:
|
||||
AutoModel = None
|
||||
AutoVideoProcessor = None
|
||||
|
||||
from .action_head import VLAJEPAActionHead
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .qwen_interface import Qwen3VLInterface
|
||||
from .world_model import ActionConditionedVideoPredictor
|
||||
|
||||
# ============================================================================
|
||||
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAModel(nn.Module):
|
||||
"""
|
||||
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
|
||||
|
||||
Components:
|
||||
- Qwen3-VL: vision-language backbone for fused embeddings
|
||||
- DiT-B: flow-matching action head for future action prediction
|
||||
- V-JEPA: world model for video frame prediction
|
||||
|
||||
Input: List[dict] native format (same as original starVLA)
|
||||
- "image": List[PIL.Image] (multi-view images)
|
||||
- "video": np.ndarray [V, T, H, W, 3]
|
||||
- "lang": str (task instruction)
|
||||
- "action": np.ndarray [T, action_dim] (optional, training only)
|
||||
- "state": np.ndarray [1, state_dim] (optional)
|
||||
"""
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
require_package("transformers", extra="vla_jepa")
|
||||
self.config = config
|
||||
|
||||
# Vision-language backbone
|
||||
self.qwen = Qwen3VLInterface(config)
|
||||
|
||||
# Tokenizer expansion for special action tokens
|
||||
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
|
||||
self.qwen.expand_tokenizer()
|
||||
)
|
||||
|
||||
# Action head (flow-matching DiT)
|
||||
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
|
||||
|
||||
# JEPA world model components
|
||||
if config.enable_world_model:
|
||||
self.video_encoder = AutoModel.from_pretrained(
|
||||
config.jepa_encoder_name,
|
||||
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
|
||||
num_views = max(1, len(config.image_features))
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
image_size = getattr(self.video_encoder.config, "image_size", None)
|
||||
if image_size is None:
|
||||
first_image_shape = next(iter(config.image_features.values())).shape
|
||||
image_size = first_image_shape[-1]
|
||||
self.video_predictor = ActionConditionedVideoPredictor(
|
||||
num_frames=config.num_video_frames // tubelet_size,
|
||||
img_size=(image_size, image_size),
|
||||
patch_size=16,
|
||||
tubelet_size=1,
|
||||
embed_dim=self.video_encoder.config.hidden_size * num_views,
|
||||
action_embed_dim=self.qwen.model.config.hidden_size,
|
||||
predictor_embed_dim=self.video_encoder.config.hidden_size,
|
||||
depth=config.predictor_depth,
|
||||
num_heads=config.predictor_num_heads,
|
||||
mlp_ratio=config.predictor_mlp_ratio,
|
||||
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
|
||||
)
|
||||
else:
|
||||
self.video_encoder = None
|
||||
self.video_processor = None
|
||||
self.video_predictor = None
|
||||
|
||||
if config.freeze_qwen:
|
||||
self.qwen.requires_grad_(False)
|
||||
|
||||
# Build prompt placeholders.
|
||||
# Use the encoder's actual tubelet_size when available (world model enabled),
|
||||
# otherwise fall back to config.
|
||||
_tubelet_size = (
|
||||
self.video_encoder.config.tubelet_size
|
||||
if config.enable_world_model
|
||||
else self.config.jepa_tubelet_size
|
||||
)
|
||||
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
|
||||
self.replace_prompt = "".join(
|
||||
token * self.config.num_action_tokens_per_timestep
|
||||
for token in self.action_tokens[:num_action_prompt_steps]
|
||||
)
|
||||
self.embodied_replace_prompt = (
|
||||
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
|
||||
)
|
||||
|
||||
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return the last decoder hidden state before the final RMSNorm.
|
||||
|
||||
The model was trained with the output of the last transformer block BEFORE
|
||||
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
|
||||
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
|
||||
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
|
||||
the correct pre-RMSNorm state, matching the training-time representation.
|
||||
"""
|
||||
captured: list[torch.Tensor] = []
|
||||
|
||||
def _hook(module, input, output):
|
||||
h = output[0] if isinstance(output, tuple) else output
|
||||
captured.append(h)
|
||||
|
||||
last_layer = self.qwen.model.model.language_model.layers[-1]
|
||||
handle = last_layer.register_forward_hook(_hook)
|
||||
try:
|
||||
self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
finally:
|
||||
handle.remove()
|
||||
|
||||
return captured[0] # [B, seq_len, H]
|
||||
|
||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||
|
||||
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Native forward pass following original starVLA VLA_JEPA.forward.
|
||||
|
||||
Args:
|
||||
examples: List of per-sample dicts with keys:
|
||||
"image" : List[PIL.Image] — multi-view images
|
||||
"video" : np.ndarray [V, T, H, W, 3]
|
||||
"lang" : str — task instruction
|
||||
"action" : np.ndarray [T, action_dim] (optional)
|
||||
"state" : np.ndarray [1, state_dim] (optional)
|
||||
|
||||
Returns:
|
||||
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
|
||||
"""
|
||||
# Unpack native format (same pattern as original VLA_JEPA.py)
|
||||
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
|
||||
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
|
||||
instructions = [ex["lang"] for ex in examples] # List[str]
|
||||
has_action = "action" in examples[0] and examples[0]["action"] is not None
|
||||
actions = [ex["action"] for ex in examples] if has_action else None
|
||||
has_state = "state" in examples[0] and examples[0]["state"] is not None
|
||||
state = [ex["state"] for ex in examples] if has_state else None
|
||||
action_is_pad = (
|
||||
[ex["action_is_pad"] for ex in examples]
|
||||
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
|
||||
batch_videos = np.stack(batch_videos)
|
||||
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
|
||||
|
||||
# ---- Step 1: QwenVL encode (same as original) ----
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
# Locate embodied-action tokens (always needed for action head)
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
# Locate action tokens (only needed for world model predictor)
|
||||
if self.config.enable_world_model:
|
||||
action_mask = torch.isin(
|
||||
qwen_inputs["input_ids"],
|
||||
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
||||
)
|
||||
action_indices = action_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
|
||||
if self.config.enable_world_model:
|
||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
||||
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
# ---- Step 2+3: JEPA Encoder + Predictor ----
|
||||
device_wm = last_hidden.device
|
||||
if not self.config.enable_world_model:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
|
||||
video_pixels = []
|
||||
for i in range(b * v):
|
||||
video_pixels.append(
|
||||
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device)
|
||||
)
|
||||
video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W]
|
||||
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
device_wm = video_embeddings.device
|
||||
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||
|
||||
if t_enc_total < 2:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
|
||||
# input_states: positions 0..T-2, gt_states: positions 1..T-1
|
||||
t_enc_ctx = t_enc_total - 1
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
|
||||
predicted_states = self.video_predictor(
|
||||
input_states.float(),
|
||||
action_tokens[:, :expected_actions].float(),
|
||||
)
|
||||
|
||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
|
||||
if not has_action:
|
||||
return {"wm_loss": wm_loss}
|
||||
|
||||
# ---- Step 4: Action Head ----
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
actions_tensor = torch.tensor(
|
||||
np.array(actions), device=last_hidden.device, dtype=torch.float32
|
||||
) # [B, T_full, action_dim]
|
||||
action_horizon = self.config.chunk_size
|
||||
actions_target = actions_tensor[:, -action_horizon:, :]
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.tensor(
|
||||
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
|
||||
) # [B, 1, state_dim]
|
||||
|
||||
repeated_diffusion_steps = self.config.repeated_diffusion_steps
|
||||
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
|
||||
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
|
||||
if state_tensor is not None:
|
||||
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
|
||||
|
||||
action_is_pad_rep = None
|
||||
if action_is_pad is not None:
|
||||
pad_tensor = torch.stack(
|
||||
[
|
||||
p.to(actions_target.device)
|
||||
if isinstance(p, Tensor)
|
||||
else torch.tensor(p, device=actions_target.device)
|
||||
for p in action_is_pad
|
||||
]
|
||||
) # [B, T_full]
|
||||
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
|
||||
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
|
||||
|
||||
action_loss = self.action_model(
|
||||
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
|
||||
)
|
||||
|
||||
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||
|
||||
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
batch_images: list[list[Image.Image]],
|
||||
instructions: list[str],
|
||||
state: np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Native action prediction following original VLA_JEPA.predict_action.
|
||||
|
||||
Args:
|
||||
batch_images: List of samples; each is List[PIL.Image] (multi-view).
|
||||
instructions: Task instructions, one per sample.
|
||||
state: Optional [B, state_dim] numpy array.
|
||||
|
||||
Returns:
|
||||
np.ndarray [B, action_horizon, action_dim] — predicted actions.
|
||||
"""
|
||||
if self.config.resize_images_to is not None:
|
||||
height, width = self.config.resize_images_to
|
||||
resampling = getattr(Image, "Resampling", Image).BOX
|
||||
batch_images = [
|
||||
[image.resize((width, height), resample=resampling) for image in sample_images]
|
||||
for sample_images in batch_images
|
||||
]
|
||||
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.from_numpy(np.array(state)).to(
|
||||
device=last_hidden.device, dtype=last_hidden.dtype
|
||||
)
|
||||
|
||||
pred_actions = self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
|
||||
) # [B, action_horizon, action_dim]
|
||||
|
||||
return pred_actions.detach().cpu().numpy()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
LeRobot adapter for VLA-JEPA.
|
||||
|
||||
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
|
||||
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
|
||||
back to LeRobot format.
|
||||
"""
|
||||
|
||||
config_class = VLAJEPAConfig
|
||||
name = "vla_jepa"
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
if dataset_meta := kwargs.get("dataset_meta"):
|
||||
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
|
||||
# compatibility), so validate_features() may have read stale dims from a pretrained
|
||||
# config. Override state_dim/action_dim from the actual dataset being used.
|
||||
ds_features = dataset_meta.features
|
||||
if OBS_STATE in ds_features:
|
||||
config.state_dim = ds_features[OBS_STATE]["shape"][0]
|
||||
if ACTION in ds_features:
|
||||
config.action_dim = ds_features[ACTION]["shape"][0]
|
||||
|
||||
self.model = VLAJEPAModel(config)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
|
||||
|
||||
# ---- Format Conversion: LeRobot → Native ----
|
||||
|
||||
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
|
||||
"""
|
||||
Convert LeRobot batch format to native VLA-JEPA examples format.
|
||||
|
||||
LeRobot format:
|
||||
batch = {
|
||||
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
|
||||
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
|
||||
"action": Tensor [B, chunk_size, action_dim], (training only)
|
||||
"task": str | List[str], (optional instruction)
|
||||
}
|
||||
|
||||
Native format (List[dict]):
|
||||
{
|
||||
"image": List[PIL.Image], # multi-view images per sample
|
||||
"video": np.ndarray [V, T, H, W, 3],
|
||||
"lang": str, # task instruction
|
||||
"action": np.ndarray [T, action_dim], # optional
|
||||
"state": np.ndarray [1, state_dim], # optional
|
||||
}
|
||||
"""
|
||||
# Determine batch size from the first image feature
|
||||
image_keys = list(self.config.image_features.keys())
|
||||
if not image_keys:
|
||||
raise ValueError("VLAJEPA requires at least one image feature.")
|
||||
first_key = image_keys[0]
|
||||
first_tensor = batch[first_key]
|
||||
batch_size = first_tensor.shape[0]
|
||||
|
||||
# ---- Collect images per sample ----
|
||||
# images_per_sample[b][v] = PIL.Image for view v
|
||||
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
|
||||
for key in image_keys:
|
||||
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
|
||||
if tensor.ndim == 5:
|
||||
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
|
||||
# index 0 is the current observation (delta=0)
|
||||
tensor = tensor[:, 0]
|
||||
for b in range(batch_size):
|
||||
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
|
||||
|
||||
# ---- Collect videos per sample ----
|
||||
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
|
||||
# Check whether any image feature has a time dimension
|
||||
video_source = None
|
||||
for k in image_keys:
|
||||
if k in batch:
|
||||
video_source = batch[k] # Use first available for shape inspection
|
||||
break
|
||||
|
||||
if video_source is None:
|
||||
raise ValueError("No image data found in batch for video construction.")
|
||||
|
||||
videos_per_sample = []
|
||||
for b in range(batch_size):
|
||||
sample_views = []
|
||||
for k in image_keys:
|
||||
t = batch[k][b] # [C, H, W] or [T, C, H, W]
|
||||
if t.ndim == 3:
|
||||
t = t.unsqueeze(0) # [1, C, H, W]
|
||||
# Convert to [T, H, W, 3] numpy
|
||||
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
|
||||
# Clamp to [0, 255]
|
||||
if t_np.max() <= 1.0:
|
||||
t_np = t_np * 255.0
|
||||
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
|
||||
sample_views.append(t_np)
|
||||
# Stack views: [V, T, H, W, 3]
|
||||
videos_per_sample.append(np.stack(sample_views, axis=0))
|
||||
|
||||
# ---- Collect instructions ----
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
instructions = ["Execute the robot action."] * batch_size
|
||||
elif isinstance(tasks, str):
|
||||
instructions = [tasks] * batch_size
|
||||
else:
|
||||
instructions = list(tasks)
|
||||
|
||||
# ---- Collect actions (training only) ----
|
||||
actions_list = None
|
||||
action_is_pad_list = None
|
||||
actions_tensor = batch.get(ACTION)
|
||||
if actions_tensor is not None:
|
||||
if actions_tensor.ndim == 2:
|
||||
actions_tensor = actions_tensor.unsqueeze(1)
|
||||
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
action_is_pad_tensor = batch.get("action_is_pad")
|
||||
if action_is_pad_tensor is not None:
|
||||
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
|
||||
|
||||
# ---- Collect state ----
|
||||
state_list = None
|
||||
state_tensor = batch.get(OBS_STATE)
|
||||
if state_tensor is not None:
|
||||
if state_tensor.ndim > 2:
|
||||
state_tensor = state_tensor[:, -1, :]
|
||||
if state_tensor.ndim == 2:
|
||||
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
|
||||
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
|
||||
# ---- Assemble native examples ----
|
||||
examples = []
|
||||
for b in range(batch_size):
|
||||
example = {
|
||||
"image": images_per_sample[b],
|
||||
"video": videos_per_sample[b],
|
||||
"lang": instructions[b],
|
||||
}
|
||||
if actions_list is not None:
|
||||
example["action"] = actions_list[b]
|
||||
if action_is_pad_list is not None:
|
||||
example["action_is_pad"] = action_is_pad_list[b]
|
||||
if state_list is not None:
|
||||
example["state"] = state_list[b]
|
||||
examples.append(example)
|
||||
|
||||
return examples
|
||||
|
||||
# ---- LeRobot Policy Interface ----
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""LeRobot train forward: convert → native forward → aggregate losses."""
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
native_output = self.model.forward(examples)
|
||||
|
||||
ref = next(iter(native_output.values()))
|
||||
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
|
||||
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
|
||||
logs = {k: v.detach().item() for k, v in native_output.items()}
|
||||
logs["loss"] = total_loss.detach().item()
|
||||
return total_loss, logs
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.model.parameters()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot inference: convert → native predict → return as Tensor."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
batch_images = [ex["image"] for ex in examples]
|
||||
instructions = [ex["lang"] for ex in examples]
|
||||
|
||||
state_np = None
|
||||
if "state" in examples[0] and examples[0]["state"] is not None:
|
||||
state_np = np.stack([ex["state"] for ex in examples])
|
||||
|
||||
actions_np = self.model.predict_action(batch_images, instructions, state_np)
|
||||
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot select_action with action queue caching."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
**kwargs,
|
||||
):
|
||||
return super().from_pretrained(pretrained_name_or_path, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
"""
|
||||
Custom loading to enable opt reinit of action head
|
||||
when loading pretrained weights with mismatched action head shapes.
|
||||
"""
|
||||
if not model.config.reinit_action_head:
|
||||
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(model_file, device=map_location)
|
||||
current = model.state_dict()
|
||||
|
||||
mismatched: list[str] = []
|
||||
filtered: dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key in current and value.shape != current[key].shape:
|
||||
mismatched.append(
|
||||
f"{key}: checkpoint {tuple(value.shape)} vs model {tuple(current[key].shape)}"
|
||||
)
|
||||
else:
|
||||
filtered[key] = value
|
||||
|
||||
if mismatched:
|
||||
logging.warning(
|
||||
f"reinit_action_head=True: skipping {len(mismatched)} tensor(s) with mismatched shapes "
|
||||
f"(randomly re-initialised):\n " + "\n ".join(mismatched)
|
||||
)
|
||||
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
|
||||
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
|
||||
log_model_loading_keys(missing_keys, unexpected_keys)
|
||||
return model
|
||||
139
src/lerobot/policies/vla_jepa/processor_vla_jepa.py
Normal file
139
src/lerobot/policies/vla_jepa/processor_vla_jepa.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
|
||||
class ClipActionsProcessorStep(ProcessorStep):
|
||||
"""Clips action tensor to [-1, 1] before unnormalization."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
transition = dict(transition)
|
||||
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
|
||||
class PreSnapGripperProcessorStep(ProcessorStep):
|
||||
"""Snaps gripper dim (index 6) to {0, 1} BEFORE unnormalization.
|
||||
|
||||
Mirrors the original starVLA LIBERO eval:
|
||||
normalized[:, 6] = np.where(normalized[:, 6] < 0.5, 0, 1)
|
||||
This ensures the unnormalizer receives an exact binary value, which is
|
||||
required when the model was trained with gripper in identity (mask=False)
|
||||
space where 0=open and 1=close.
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] >= 7:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., 6] = (a[..., 6] >= 0.5).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
|
||||
class BinarizeGripperProcessorStep(ProcessorStep):
|
||||
"""Binarizes gripper dim (index 6) after unnormalization.
|
||||
|
||||
Maps continuous value to {-1, 1}: > 0.5 → -1, <= 0.5 → 1 (matches starVLA convention).
|
||||
Only applied when action has >= 7 dimensions.
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] >= 7:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., 6] = 1.0 - 2.0 * (a[..., 6] > 0.5).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
def make_vla_jepa_pre_post_processors(
|
||||
config: VLAJEPAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
features = {**config.input_features, **config.output_features}
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps: list[ProcessorStep] = []
|
||||
if config.clip_normalized_actions:
|
||||
output_steps.append(ClipActionsProcessorStep())
|
||||
if config.pre_snap_gripper_action:
|
||||
output_steps.append(PreSnapGripperProcessorStep())
|
||||
output_steps.append(
|
||||
UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
)
|
||||
)
|
||||
if config.binarize_gripper_action:
|
||||
output_steps.append(BinarizeGripperProcessorStep())
|
||||
output_steps.append(DeviceProcessorStep(device="cpu"))
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_new_line_processor")
|
||||
class VLAJEPANewLineProcessor(ComplementaryDataProcessorStep):
|
||||
def complementary_data(self, complementary_data):
|
||||
return complementary_data
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
103
src/lerobot/policies/vla_jepa/qwen_interface.py
Normal file
103
src/lerobot/policies/vla_jepa/qwen_interface.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
||||
else:
|
||||
AutoProcessor = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
class Qwen3VLInterface(torch.nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
config.qwen_model_name,
|
||||
torch_dtype=self._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
|
||||
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
|
||||
self.model.config.hidden_size = self.model.config.text_config.hidden_size
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||
if dtype_name == "float32":
|
||||
return torch.float32
|
||||
if dtype_name == "float16":
|
||||
return torch.float16
|
||||
return torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
|
||||
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
|
||||
# is required for Qwen embedding/lm_head checkpoint shapes to match.
|
||||
max_action_tokens = self.config.chunk_size * 4
|
||||
tokenizer = self.processor.tokenizer
|
||||
action_tokens = []
|
||||
action_token_ids = []
|
||||
for idx in range(max_action_tokens):
|
||||
token = self.config.special_action_token.format(idx)
|
||||
action_tokens.append(token)
|
||||
if token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([token], special_tokens=True)
|
||||
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
|
||||
|
||||
embodied_action_token = self.config.embodied_action_token
|
||||
if embodied_action_token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
|
||||
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
|
||||
|
||||
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
|
||||
self.model.resize_token_embeddings(len(tokenizer))
|
||||
return action_tokens, action_token_ids, embodied_action_token_id
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: Sequence[Sequence[Image.Image]],
|
||||
instructions: Sequence[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
messages = []
|
||||
for sample_images, instruction in zip(images, instructions, strict=True):
|
||||
prompt = self.config.prompt_template.format(
|
||||
instruction=instruction,
|
||||
actions=action_prompt,
|
||||
e_actions=embodied_prompt,
|
||||
)
|
||||
content = [{"type": "image", "image": img} for img in sample_images]
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages.append([{"role": "user", "content": content}])
|
||||
|
||||
batch_inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
processor_kwargs={"padding": True, "return_tensors": "pt"},
|
||||
)
|
||||
return batch_inputs.to(self.model.device)
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = image.float()
|
||||
if image.max() <= 1.0:
|
||||
image = image * 255.0
|
||||
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
|
||||
if image.shape[-1] == 1:
|
||||
image = np.repeat(image, 3, axis=-1)
|
||||
return Image.fromarray(image)
|
||||
404
src/lerobot/policies/vla_jepa/world_model.py
Normal file
404
src/lerobot/policies/vla_jepa/world_model.py
Normal file
@@ -0,0 +1,404 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
|
||||
|
||||
def build_action_block_causal_attention_mask(
|
||||
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
|
||||
) -> torch.Tensor:
|
||||
tokens_per_frame = add_tokens + grid_height * grid_width
|
||||
num_tokens = num_frames * tokens_per_frame
|
||||
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
|
||||
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
|
||||
local_window_time = num_frames
|
||||
|
||||
for current_frame in range(num_frames):
|
||||
first_context_frame = max(0, current_frame - local_window_time + 1)
|
||||
for context_frame in range(first_context_frame, current_frame + 1):
|
||||
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
|
||||
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
|
||||
mask[row, col] = mask_block
|
||||
return mask
|
||||
|
||||
|
||||
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, dim = x.size()
|
||||
if dim % 2 != 0:
|
||||
raise ValueError("Embedding dimension must be even for rotary position encoding.")
|
||||
|
||||
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
|
||||
omega /= dim / 2.0
|
||||
omega = 1.0 / 10000**omega
|
||||
freqs = torch.einsum("..., f -> ... f", pos, omega)
|
||||
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
|
||||
y = x.unflatten(-1, (-1, 2))
|
||||
y1, y2 = y.unbind(dim=-1)
|
||||
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
|
||||
return x * emb_cos + y * emb_sin
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
def __init__(self, drop_prob: float = 0.0) -> None:
|
||||
super().__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.drop_prob == 0.0 or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_()
|
||||
return x.div(keep_prob) * random_tensor
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int | None = None,
|
||||
out_features: int | None = None,
|
||||
act_layer: type[nn.Module] = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class ACRoPEAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: float | None = None,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = qk_scale or self.head_dim**-0.5
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop_prob = proj_drop
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.use_sdpa = use_sdpa
|
||||
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.grid_size = grid_size
|
||||
self.is_causal = is_causal
|
||||
|
||||
@staticmethod
|
||||
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
return ids // int(height * width)
|
||||
|
||||
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
ids = ids - int(height * width) * frame_ids
|
||||
return ids // width
|
||||
|
||||
def separate_positions(
|
||||
self, ids: torch.Tensor, height: int, width: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
height_ids = self._get_height_pos(ids, height, width)
|
||||
width_ids = ids - int(height * width) * frame_ids - width * height_ids
|
||||
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor | None = None,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_tokens, channels = x.size()
|
||||
if num_frames is None or grid_height is None or grid_width is None:
|
||||
raise ValueError("num_frames, grid_height and grid_width are required.")
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
else:
|
||||
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
|
||||
h_mask *= self.grid_size / grid_height
|
||||
w_mask *= self.grid_size / grid_width
|
||||
|
||||
if action_tokens > 0:
|
||||
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
|
||||
action_q, action_k, action_v = [], [], []
|
||||
for idx in range(action_tokens):
|
||||
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
|
||||
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
qd = rotate_queries_or_keys(
|
||||
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
kd = rotate_queries_or_keys(
|
||||
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
qr = q[..., self.d_dim :]
|
||||
kr = k[..., self.d_dim :]
|
||||
action_q.append(
|
||||
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_k.append(
|
||||
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||
|
||||
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
|
||||
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
|
||||
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
|
||||
x = x[:, :, action_tokens:, :].flatten(1, 2)
|
||||
|
||||
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
offset = 0
|
||||
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
offset += self.d_dim
|
||||
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
offset += self.h_dim
|
||||
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
offset += self.w_dim
|
||||
|
||||
if offset < self.head_dim:
|
||||
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
|
||||
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
|
||||
else:
|
||||
q = torch.cat([qd, qh, qw], dim=-1)
|
||||
k = torch.cat([kd, kh, kw], dim=-1)
|
||||
|
||||
if action_tokens > 0:
|
||||
|
||||
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
|
||||
frame_tokens = frame_tokens.view(
|
||||
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
|
||||
)
|
||||
action_token_values = action_token_values.view(
|
||||
batch_size, self.num_heads, num_frames, action_tokens, -1
|
||||
)
|
||||
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
|
||||
|
||||
q = merge(q, action_q)
|
||||
k = merge(k, action_k)
|
||||
v = merge(v, action_v)
|
||||
|
||||
if attn_mask is not None or self.use_sdpa:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
|
||||
)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
|
||||
x = self.proj(x)
|
||||
return self.proj_drop(x)
|
||||
|
||||
|
||||
class ACBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
qk_scale: float | None = None,
|
||||
drop: float = 0.0,
|
||||
attn_drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
use_rope: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
if not use_rope:
|
||||
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
|
||||
self.attn = ACRoPEAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
use_sdpa=use_sdpa,
|
||||
is_causal=is_causal,
|
||||
grid_size=grid_size,
|
||||
proj_drop=drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = MLP(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=nn.GELU,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
y = self.norm1(x)
|
||||
y = self.attn(
|
||||
y,
|
||||
mask=None,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=grid_height,
|
||||
grid_width=grid_width,
|
||||
action_tokens=action_tokens,
|
||||
)
|
||||
x = x + self.drop_path(y)
|
||||
y = self.norm2(x)
|
||||
return x + self.drop_path(self.mlp(y))
|
||||
|
||||
|
||||
class ActionConditionedVideoPredictor(nn.Module):
|
||||
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_frames: int,
|
||||
img_size: tuple[int, int],
|
||||
patch_size: int,
|
||||
tubelet_size: int,
|
||||
embed_dim: int,
|
||||
action_embed_dim: int,
|
||||
predictor_embed_dim: int,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
num_action_tokens_per_step: int,
|
||||
use_extrinsics: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_frame_causal = True
|
||||
self.use_extrinsics = use_extrinsics
|
||||
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
|
||||
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
|
||||
|
||||
self.img_height, self.img_width = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_frames = num_frames
|
||||
self.tubelet_size = tubelet_size
|
||||
self.grid_height = self.img_height // self.patch_size
|
||||
self.grid_width = self.img_width // self.patch_size
|
||||
|
||||
self.predictor_blocks = nn.ModuleList(
|
||||
[
|
||||
ACBlock(
|
||||
dim=predictor_embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=True,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
|
||||
grid_size=self.grid_height,
|
||||
use_rope=True,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
|
||||
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
|
||||
self.num_action_tokens_per_step = num_action_tokens_per_step
|
||||
|
||||
@property
|
||||
def norm(self) -> nn.LayerNorm:
|
||||
return self.predictor_norm
|
||||
|
||||
@property
|
||||
def proj(self) -> nn.Linear:
|
||||
return self.predictor_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
frame_tokens: torch.Tensor,
|
||||
action_tokens: torch.Tensor,
|
||||
extrinsics: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
|
||||
x = self.predictor_embed(frame_tokens)
|
||||
batch_size, num_context_tokens, hidden_dim = x.size()
|
||||
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
|
||||
|
||||
actions = self.action_encoder(action_tokens)
|
||||
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
|
||||
cond_tokens = actions.shape[2]
|
||||
|
||||
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
|
||||
if self.use_extrinsics:
|
||||
if extrinsics is None:
|
||||
raise ValueError("extrinsics are required when use_extrinsics=True.")
|
||||
cond_tokens += 1
|
||||
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
|
||||
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
|
||||
else:
|
||||
x = torch.cat([actions, x], dim=2).flatten(1, 2)
|
||||
|
||||
attn_mask = build_action_block_causal_attention_mask(
|
||||
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
|
||||
)
|
||||
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
|
||||
|
||||
for block in self.predictor_blocks:
|
||||
x = block(
|
||||
x,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=self.grid_height,
|
||||
grid_width=self.grid_width,
|
||||
action_tokens=cond_tokens,
|
||||
)
|
||||
|
||||
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
|
||||
x = x[:, :, cond_tokens:, :].flatten(1, 2)
|
||||
x = self.predictor_norm(x)
|
||||
return self.predictor_proj(x)
|
||||
@@ -1 +0,0 @@
|
||||
"""Lightweight vendored OpenPI PyTorch modules for PI0/PI05 parity tests."""
|
||||
@@ -1,22 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
width: int
|
||||
depth: int
|
||||
mlp_dim: int
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
head_dim: int
|
||||
|
||||
|
||||
def get_config(variant: str) -> Config:
|
||||
"""Return the Gemma shape config needed by the OpenPI PyTorch model."""
|
||||
if variant == "dummy":
|
||||
return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16)
|
||||
if variant == "gemma_300m":
|
||||
return Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256)
|
||||
if variant == "gemma_2b":
|
||||
return Config(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256)
|
||||
raise ValueError(f"Unknown variant: {variant}")
|
||||
@@ -1,300 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaForCausalLM,
|
||||
_gated_residual,
|
||||
layernorm_forward,
|
||||
)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vlm_config,
|
||||
action_expert_config,
|
||||
use_adarms=None,
|
||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||
):
|
||||
if use_adarms is None:
|
||||
use_adarms = [False, False]
|
||||
super().__init__()
|
||||
|
||||
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
||||
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
||||
vlm_config_hf.image_token_index = 257152
|
||||
vlm_config_hf.text_config.hidden_size = vlm_config.width
|
||||
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
|
||||
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
|
||||
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
|
||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||
head_dim=action_expert_config.head_dim,
|
||||
hidden_size=action_expert_config.width,
|
||||
intermediate_size=action_expert_config.mlp_dim,
|
||||
num_attention_heads=action_expert_config.num_heads,
|
||||
num_hidden_layers=action_expert_config.depth,
|
||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||
vocab_size=257152,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
dtype="float32",
|
||||
use_adarms=use_adarms[1],
|
||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||
)
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
|
||||
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
||||
if precision == "bfloat16":
|
||||
self.to(dtype=torch.bfloat16)
|
||||
elif precision == "float32":
|
||||
self.to(dtype=torch.float32)
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
]
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if any(selector in name for selector in params_to_keep_float32):
|
||||
param.data = param.data.to(dtype=torch.float32)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
# Transformers 5.4 no longer divides PaliGemma image features by sqrt(hidden_size),
|
||||
# so the upstream helper now matches OpenPI's patched PaliGemma image-scale semantics.
|
||||
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-c916907e7e52ac85ee1a1527560eae4656cd6c76141ceb1fe3da61bd5f697d2a
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
inputs_embeds: list[torch.FloatTensor] | None = None,
|
||||
use_cache: bool | None = None,
|
||||
adarms_cond: list[torch.Tensor] | None = None,
|
||||
):
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
|
||||
)
|
||||
prefix_past_key_values = prefix_output.past_key_values
|
||||
prefix_output = prefix_output.last_hidden_state
|
||||
suffix_output = None
|
||||
elif inputs_embeds[0] is None:
|
||||
suffix_output = self.gemma_expert.model.forward(
|
||||
inputs_embeds=inputs_embeds[1],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
|
||||
)
|
||||
suffix_output = suffix_output.last_hidden_state
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
hasattr(self.gemma_expert.model, "gradient_checkpointing")
|
||||
and self.gemma_expert.model.gradient_checkpointing
|
||||
and self.training
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Force enable gradient checkpointing if we're in training mode and the model supports it
|
||||
if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
|
||||
if not self.gemma_expert.model.gradient_checkpointing:
|
||||
print("Forcing gradient checkpointing to be enabled for Gemma expert model")
|
||||
self.gemma_expert.model.gradient_checkpointing = True
|
||||
use_gradient_checkpointing = True
|
||||
|
||||
# Debug gradient checkpointing status
|
||||
if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
|
||||
print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
|
||||
print(f"Model training mode: {self.training}")
|
||||
print(
|
||||
f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
|
||||
)
|
||||
if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
|
||||
print(
|
||||
f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
|
||||
)
|
||||
self._debug_gc_printed = True
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layernorm_forward(
|
||||
layer.input_layernorm, hidden_states, adarms_cond[i]
|
||||
)
|
||||
gates.append(gate)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# Concatenate and process attention
|
||||
query_states = torch.cat(query_states, dim=2)
|
||||
key_states = torch.cat(key_states, dim=2)
|
||||
value_states = torch.cat(value_states, dim=2)
|
||||
|
||||
dummy_tensor = torch.zeros(
|
||||
query_states.shape[0],
|
||||
query_states.shape[2],
|
||||
query_states.shape[-1],
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = self.paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
self.paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = self.paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||
|
||||
# first residual
|
||||
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||
after_first_residual = out_emb.clone()
|
||||
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
out_emb = layer.mlp(out_emb)
|
||||
# second residual
|
||||
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||
outputs_embeds.append(out_emb)
|
||||
start_pos = end_pos
|
||||
|
||||
return outputs_embeds
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
# Old code removed - now using compute_layer_complete function above
|
||||
|
||||
# final norm
|
||||
# Define final norm computation function for gradient checkpointing
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
# Apply gradient checkpointing to final norm if enabled
|
||||
if use_gradient_checkpointing:
|
||||
outputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_final_norms,
|
||||
inputs_embeds,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
)
|
||||
else:
|
||||
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
|
||||
|
||||
prefix_output = outputs_embeds[0]
|
||||
suffix_output = outputs_embeds[1]
|
||||
prefix_past_key_values = None
|
||||
|
||||
return [prefix_output, suffix_output], prefix_past_key_values
|
||||
@@ -1,79 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
|
||||
|
||||
def resize_with_pad_torch(
|
||||
images: torch.Tensor,
|
||||
height: int,
|
||||
width: int,
|
||||
mode: str = "bilinear",
|
||||
) -> torch.Tensor:
|
||||
"""PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
|
||||
by padding with black. If the image is float32, it must be in the range [-1, 1].
|
||||
|
||||
Args:
|
||||
images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
|
||||
height: Target height
|
||||
width: Target width
|
||||
mode: Interpolation mode ('bilinear', 'nearest', etc.)
|
||||
|
||||
Returns:
|
||||
Resized and padded tensor with same shape format as input
|
||||
"""
|
||||
# Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
|
||||
if images.shape[-1] <= 4: # Assume channels-last format
|
||||
channels_last = True
|
||||
# Convert to channels-first for torch operations
|
||||
if images.dim() == 3:
|
||||
images = images.unsqueeze(0) # Add batch dimension
|
||||
images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w]
|
||||
else:
|
||||
channels_last = False
|
||||
if images.dim() == 3:
|
||||
images = images.unsqueeze(0) # Add batch dimension
|
||||
|
||||
batch_size, channels, cur_height, cur_width = images.shape
|
||||
|
||||
# Calculate resize ratio
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
|
||||
# Resize
|
||||
resized_images = F.interpolate(
|
||||
images,
|
||||
size=(resized_height, resized_width),
|
||||
mode=mode,
|
||||
align_corners=False if mode == "bilinear" else None,
|
||||
)
|
||||
|
||||
# Handle dtype-specific clipping
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
|
||||
# Calculate padding
|
||||
pad_h0, remainder_h = divmod(height - resized_height, 2)
|
||||
pad_h1 = pad_h0 + remainder_h
|
||||
pad_w0, remainder_w = divmod(width - resized_width, 2)
|
||||
pad_w1 = pad_w0 + remainder_w
|
||||
|
||||
# Pad
|
||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
||||
padded_images = F.pad(
|
||||
resized_images,
|
||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||
mode="constant",
|
||||
value=constant_value,
|
||||
)
|
||||
|
||||
# Convert back to original format if needed
|
||||
if channels_last:
|
||||
padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
if batch_size == 1 and images.shape[0] == 1:
|
||||
padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
|
||||
|
||||
return padded_images
|
||||
@@ -1,471 +0,0 @@
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
import tests.policies.pi0_pi05.openpi_pytorch.gemma as _gemma
|
||||
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as _preprocessing
|
||||
from tests.policies.pi0_pi05.openpi_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "cpu":
|
||||
# CPU doesn't support bfloat16, use float32 instead
|
||||
if target_dtype == torch.bfloat16:
|
||||
return torch.float32
|
||||
if target_dtype == torch.float64:
|
||||
return torch.float64
|
||||
return target_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
|
||||
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
||||
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,))
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
"""Copied from big_vision.
|
||||
|
||||
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
||||
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
||||
setup several types of attention, for example:
|
||||
|
||||
[[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
themselves and the last 3 tokens have a causal attention. The first
|
||||
entry could also be a 1 without changing behaviour.
|
||||
|
||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
Args:
|
||||
input_mask: bool[B, N] true if its part of the input, false if padding.
|
||||
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
||||
it and 0 where it shares the same attention mask as the previous token.
|
||||
"""
|
||||
if att_masks.ndim != 2:
|
||||
raise ValueError(att_masks.ndim)
|
||||
if pad_masks.ndim != 2:
|
||||
raise ValueError(pad_masks.ndim)
|
||||
|
||||
cumsum = torch.cumsum(att_masks, dim=1)
|
||||
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
class PI0Pytorch(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.pi05 = config.pi05
|
||||
|
||||
paligemma_config = _gemma.get_config(config.paligemma_variant)
|
||||
action_expert_config = _gemma.get_config(config.action_expert_variant)
|
||||
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||
paligemma_config,
|
||||
action_expert_config,
|
||||
use_adarms=[False, True] if self.pi05 else [False, False],
|
||||
precision=config.dtype,
|
||||
)
|
||||
|
||||
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)
|
||||
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)
|
||||
|
||||
if self.pi05:
|
||||
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
else:
|
||||
self.state_proj = nn.Linear(config.action_dim, action_expert_config.width)
|
||||
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
|
||||
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
|
||||
torch.set_float32_matmul_precision("high")
|
||||
if config.pytorch_compile_mode is not None:
|
||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode)
|
||||
|
||||
# Initialize gradient checkpointing flag
|
||||
self.gradient_checkpointing_enabled = False
|
||||
|
||||
# The upstream OpenPI module verifies a site-package Transformers patch here.
|
||||
# This vendored test copy instead routes through LeRobot's local PiGemma compatibility layer.
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||
|
||||
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
|
||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def is_gradient_checkpointing_enabled(self):
|
||||
"""Check if gradient checkpointing is enabled."""
|
||||
return self.gradient_checkpointing_enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def _prepare_attention_masks_4d(self, att_2d_masks):
|
||||
"""Helper method to prepare 4D attention masks for transformer."""
|
||||
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
||||
return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
|
||||
|
||||
def _preprocess_observation(self, observation, *, train=True):
|
||||
"""Helper method to preprocess observation."""
|
||||
observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
|
||||
return (
|
||||
list(observation.images.values()),
|
||||
list(observation.image_masks.values()),
|
||||
observation.tokenized_prompt,
|
||||
observation.tokenized_prompt_mask,
|
||||
observation.state,
|
||||
)
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
return torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
||||
for PaliGemma transformer processing.
|
||||
"""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# Process images
|
||||
for img, img_mask in zip(images, img_masks, strict=True):
|
||||
|
||||
def image_embed_func(img):
|
||||
return self.paligemma_with_expert.embed_image(img)
|
||||
|
||||
img_emb = self._apply_checkpoint(image_embed_func, img)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
||||
|
||||
# Create attention masks so that image tokens attend to each other
|
||||
att_masks += [0] * num_img_embs
|
||||
|
||||
# Process language tokens
|
||||
def lang_embed_func(lang_tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||
# Transformers > 5.4 scales Gemma token embeddings inside embed_tokens, matching
|
||||
# OpenPI's former explicit sqrt(hidden_size) multiply without applying it twice.
|
||||
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834
|
||||
return lang_emb
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
||||
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(lang_masks)
|
||||
|
||||
# full attention between image and language inputs
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
|
||||
# Get batch size from the first dimension of the concatenated tensors
|
||||
bsize = pad_masks.shape[0]
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def embed_suffix(self, state, noisy_actions, timestep):
|
||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
if not self.pi05:
|
||||
if self.state_proj.weight.dtype == torch.float32:
|
||||
state = state.to(torch.float32)
|
||||
|
||||
# Embed state
|
||||
def state_proj_func(state):
|
||||
return self.state_proj(state)
|
||||
|
||||
state_emb = self._apply_checkpoint(state_proj_func, state)
|
||||
|
||||
embs.append(state_emb[:, None, :])
|
||||
bsize = state_emb.shape[0]
|
||||
device = state_emb.device
|
||||
|
||||
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
||||
pad_masks.append(state_mask)
|
||||
|
||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||
att_masks += [1]
|
||||
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.action_in_proj.out_features,
|
||||
min_period=4e-3,
|
||||
max_period=4.0,
|
||||
device=timestep.device,
|
||||
)
|
||||
time_emb = time_emb.type(dtype=timestep.dtype)
|
||||
|
||||
# Fuse timestep + action information using an MLP
|
||||
def action_proj_func(noisy_actions):
|
||||
return self.action_in_proj(noisy_actions)
|
||||
|
||||
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
|
||||
if not self.pi05:
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
# Apply MLP layers
|
||||
def mlp_func(action_time_emb):
|
||||
x = self.action_time_mlp_in(action_time_emb)
|
||||
x = F.silu(x) # swish == silu
|
||||
return self.action_time_mlp_out(x)
|
||||
|
||||
action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
|
||||
adarms_cond = None
|
||||
else:
|
||||
# time MLP (for adaRMS)
|
||||
def time_mlp_func(time_emb):
|
||||
x = self.time_mlp_in(time_emb)
|
||||
x = F.silu(x) # swish == silu
|
||||
x = self.time_mlp_out(x)
|
||||
return F.silu(x)
|
||||
|
||||
time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
|
||||
action_time_emb = action_emb
|
||||
adarms_cond = time_emb
|
||||
|
||||
# Add to input tokens
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
att_masks += [1] + ([0] * (self.config.action_horizon - 1))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(self, observation, actions, noise=None, time=None) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
|
||||
observation, train=True
|
||||
)
|
||||
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
|
||||
# Prepare attention masks
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
|
||||
|
||||
# Apply gradient checkpointing if enabled
|
||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
return suffix_out
|
||||
|
||||
suffix_out = self._apply_checkpoint(
|
||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
|
||||
# Apply gradient checkpointing to final action projection if enabled
|
||||
def action_out_proj_func(suffix_out):
|
||||
return self.action_out_proj(suffix_out)
|
||||
|
||||
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
|
||||
|
||||
return F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = observation.state.shape[0]
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
|
||||
observation, train=False
|
||||
)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
# Compute image and language key value cache
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks_4d,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
|
||||
# Euler step - use new tensor assignment instead of in-place operation
|
||||
x_t = x_t + dt * v_t
|
||||
time += dt
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
||||
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
# Prepare attention masks
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[None, suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
return self.action_out_proj(suffix_out)
|
||||
@@ -1,179 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from tests.policies.pi0_pi05.openpi_pytorch import image_tools
|
||||
|
||||
logger = logging.getLogger("openpi")
|
||||
|
||||
# Constants moved from model.py
|
||||
IMAGE_KEYS = (
|
||||
"base_0_rgb",
|
||||
"left_wrist_0_rgb",
|
||||
"right_wrist_0_rgb",
|
||||
)
|
||||
|
||||
IMAGE_RESOLUTION = (224, 224)
|
||||
|
||||
|
||||
def preprocess_observation_pytorch(
|
||||
observation,
|
||||
*,
|
||||
train: bool = False,
|
||||
image_keys: Sequence[str] = IMAGE_KEYS,
|
||||
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
|
||||
):
|
||||
"""Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
|
||||
|
||||
This function avoids complex type annotations that can cause torch.compile issues.
|
||||
"""
|
||||
if not set(image_keys).issubset(observation.images):
|
||||
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
|
||||
|
||||
batch_shape = observation.state.shape[:-1]
|
||||
|
||||
out_images = {}
|
||||
for key in image_keys:
|
||||
image = observation.images[key]
|
||||
|
||||
# TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
|
||||
# Handle both [B, C, H, W] and [B, H, W, C] formats
|
||||
is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1
|
||||
|
||||
if is_channels_first:
|
||||
# Convert [B, C, H, W] to [B, H, W, C] for processing
|
||||
image = image.permute(0, 2, 3, 1)
|
||||
|
||||
if image.shape[1:3] != image_resolution:
|
||||
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
|
||||
image = image_tools.resize_with_pad_torch(image, *image_resolution)
|
||||
|
||||
if train:
|
||||
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
|
||||
image = image / 2.0 + 0.5
|
||||
|
||||
# Apply PyTorch-based augmentations
|
||||
if "wrist" not in key:
|
||||
# Geometric augmentations for non-wrist cameras
|
||||
height, width = image.shape[1:3]
|
||||
|
||||
# Random crop and resize
|
||||
crop_height = int(height * 0.95)
|
||||
crop_width = int(width * 0.95)
|
||||
|
||||
# Random crop
|
||||
max_h = height - crop_height
|
||||
max_w = width - crop_width
|
||||
if max_h > 0 and max_w > 0:
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
|
||||
start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
|
||||
image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
|
||||
|
||||
# Resize back to original size
|
||||
image = torch.nn.functional.interpolate(
|
||||
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
size=(height, width),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Random rotation (small angles)
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees
|
||||
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
|
||||
# Convert to radians
|
||||
angle_rad = angle * torch.pi / 180.0
|
||||
|
||||
# Create rotation matrix
|
||||
cos_a = torch.cos(angle_rad)
|
||||
sin_a = torch.sin(angle_rad)
|
||||
|
||||
# Apply rotation using grid_sample
|
||||
grid_x = torch.linspace(-1, 1, width, device=image.device)
|
||||
grid_y = torch.linspace(-1, 1, height, device=image.device)
|
||||
|
||||
# Create meshgrid
|
||||
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
|
||||
|
||||
# Expand to batch dimension
|
||||
grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
|
||||
grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)
|
||||
|
||||
# Apply rotation transformation
|
||||
grid_x_rot = grid_x * cos_a - grid_y * sin_a
|
||||
grid_y_rot = grid_x * sin_a + grid_y * cos_a
|
||||
|
||||
# Stack and reshape for grid_sample
|
||||
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
|
||||
|
||||
image = torch.nn.functional.grid_sample(
|
||||
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
grid,
|
||||
mode="bilinear",
|
||||
padding_mode="zeros",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Color augmentations for all cameras
|
||||
# Random brightness
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
brightness_factor = (
|
||||
0.7 + torch.rand(1, device=image.device) * 0.6
|
||||
) # Random factor between 0.7 and 1.3
|
||||
image = image * brightness_factor
|
||||
|
||||
# Random contrast
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
contrast_factor = (
|
||||
0.6 + torch.rand(1, device=image.device) * 0.8
|
||||
) # Random factor between 0.6 and 1.4
|
||||
mean = image.mean(dim=[1, 2, 3], keepdim=True)
|
||||
image = (image - mean) * contrast_factor + mean
|
||||
|
||||
# Random saturation (convert to HSV, modify S, convert back)
|
||||
# For simplicity, we'll just apply a random scaling to the color channels
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
saturation_factor = (
|
||||
0.5 + torch.rand(1, device=image.device) * 1.0
|
||||
) # Random factor between 0.5 and 1.5
|
||||
gray = image.mean(dim=-1, keepdim=True)
|
||||
image = gray + (image - gray) * saturation_factor
|
||||
|
||||
# Clamp values to [0, 1]
|
||||
image = torch.clamp(image, 0, 1)
|
||||
|
||||
# Back to [-1, 1]
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
# Convert back to [B, C, H, W] format if it was originally channels-first
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
|
||||
|
||||
out_images[key] = image
|
||||
|
||||
# obtain mask
|
||||
out_masks = {}
|
||||
for key in out_images:
|
||||
if key not in observation.image_masks:
|
||||
# do not mask by default
|
||||
out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)
|
||||
else:
|
||||
out_masks[key] = observation.image_masks[key]
|
||||
|
||||
# Create a simple object with the required attributes instead of using the complex Observation class
|
||||
class SimpleProcessedObservation:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
return SimpleProcessedObservation(
|
||||
images=out_images,
|
||||
image_masks=out_masks,
|
||||
state=observation.state,
|
||||
tokenized_prompt=observation.tokenized_prompt,
|
||||
tokenized_prompt_mask=observation.tokenized_prompt_mask,
|
||||
token_ar_mask=observation.token_ar_mask,
|
||||
token_loss_mask=observation.token_loss_mask,
|
||||
)
|
||||
@@ -1,101 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi05 import PI05Config # noqa: E402
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
|
||||
assert_cache_stability,
|
||||
assert_compiled_output_matches_eager,
|
||||
assert_explain_has_no_graph_breaks,
|
||||
benchmark_runtime,
|
||||
make_compile_config,
|
||||
reset_compile_state,
|
||||
)
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
|
||||
def _make_model(*, compile_model):
|
||||
return PI05Pytorch(make_compile_config(PI05Config, compile_model=compile_model)).cuda().eval()
|
||||
|
||||
|
||||
def _make_dummy_inputs(config):
|
||||
device = torch.device("cuda")
|
||||
common = {
|
||||
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
|
||||
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
|
||||
"tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
|
||||
"masks": torch.ones(1, 5, dtype=torch.bool, device=device),
|
||||
}
|
||||
forward_kwargs = {
|
||||
**common,
|
||||
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"time": torch.rand(1, device=device),
|
||||
}
|
||||
sample_kwargs = {
|
||||
**common,
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"num_steps": config.num_inference_steps,
|
||||
}
|
||||
return forward_kwargs, sample_kwargs
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_torch_compile_forward_and_sample_actions():
|
||||
if not hasattr(torch, "compile"):
|
||||
pytest.skip("torch.compile is not available")
|
||||
if not torch._dynamo.is_dynamo_supported():
|
||||
pytest.skip("torch._dynamo is not supported on this platform")
|
||||
|
||||
torch.manual_seed(0)
|
||||
eager_model = _make_model(compile_model=False)
|
||||
torch.manual_seed(0)
|
||||
compiled_model = _make_model(compile_model=True)
|
||||
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
|
||||
|
||||
try:
|
||||
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
|
||||
|
||||
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi05.forward")
|
||||
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi05.sample_actions")
|
||||
|
||||
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi05.forward")
|
||||
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi05.sample_actions")
|
||||
|
||||
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi05.forward")
|
||||
benchmark_runtime(
|
||||
eager_model.sample_actions,
|
||||
compiled_model.sample_actions,
|
||||
sample_kwargs,
|
||||
"pi05.sample_actions",
|
||||
)
|
||||
finally:
|
||||
reset_compile_state()
|
||||
del eager_model
|
||||
del compiled_model
|
||||
torch.cuda.empty_cache()
|
||||
@@ -14,56 +14,52 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compare LeRobot PI0.5 against the vendored OpenPI PyTorch reference."""
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip if openpi or transformers is not available
|
||||
pytest.importorskip("openpi")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.configs import PreTrainedConfig # noqa: E402
|
||||
from lerobot.policies.pi05 import PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
deterministic_openpi_forward_preprocess,
|
||||
fix_reference_state_dict,
|
||||
fixed_flow_sampling,
|
||||
load_openpi_reference_state_dict,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||
|
||||
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
|
||||
from lerobot.types import PolicyAction # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 200
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
COMPILE_MODE = "default"
|
||||
FORWARD_RTOL = 1e-4
|
||||
FORWARD_ATOL = 1e-4
|
||||
SAMPLE_RTOL = 1e-2
|
||||
SAMPLE_ATOL = 5e-3
|
||||
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
OBS_STATE: {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
ACTION: {
|
||||
"action": {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
@@ -92,15 +88,6 @@ DUMMY_DATASET_STATS = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_cuda_after_test():
|
||||
yield
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
class PI05BaseOriginalConfig:
|
||||
action_dim: int = DUMMY_ACTION_DIM
|
||||
action_horizon: int = DUMMY_ACTION_HORIZON
|
||||
@@ -109,163 +96,341 @@ class PI05BaseOriginalConfig:
|
||||
precision: str = "float32"
|
||||
pi05: bool = True
|
||||
dtype: str = "float32"
|
||||
pytorch_compile_mode: str | None = None
|
||||
|
||||
|
||||
def instantiate_lerobot_pi05(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
config = PreTrainedConfig.from_pretrained("lerobot/pi05_base")
|
||||
config.device = str(DEVICE)
|
||||
config.dtype = "float32"
|
||||
config.compile_model = compile_model
|
||||
config.compile_mode = COMPILE_MODE
|
||||
config.gradient_checkpointing = gradient_checkpointing
|
||||
def instantiate_lerobot_pi05(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI05Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI05Policy.from_pretrained(pretrained_name_or_path="lerobot/pi05_base", strict=True)
|
||||
else:
|
||||
config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI05Policy(config)
|
||||
|
||||
policy = PI05Policy.from_pretrained("lerobot/pi05_base", config=config, strict=True)
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = str(DEVICE)
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
return policy, preprocessor
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi05_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
|
||||
|
||||
def instantiate_original_pi05():
|
||||
policy = PI0Pytorch(PI05BaseOriginalConfig()).to(DEVICE)
|
||||
def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None):
|
||||
config = PI05BaseOriginalConfig()
|
||||
policy = PI0Pytorch(config)
|
||||
|
||||
# NOTE: `lerobot/pi05_base` 的 LeRobot loader 和 PI0 一样会在 strict load 前做 key
|
||||
# 兼容转换,因此预期没有 missing_keys 或 unexpected_keys。vendored reference 则是裸
|
||||
# `nn.Module`,需要在测试侧补齐 checkpoint 与模块命名之间的最小差异。
|
||||
# NOTE: `lm_head.weight` 是 PaliGemma tied embedding 的保存名;LeRobot 的
|
||||
# from_pretrained 会把它映射到内部 `embed_tokens.weight`,而 reference 模型没有这层
|
||||
# loader,所以这里手动复用同一份 tensor,避免把权重别名差异误判成模型差异。
|
||||
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi05_base"))
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
assert missing_keys == []
|
||||
assert unexpected_keys == []
|
||||
if from_pretrained:
|
||||
try:
|
||||
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi05_base)...")
|
||||
|
||||
# Download the model from HuggingFace Hub
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the entire repository
|
||||
if model_path and os.path.exists(model_path):
|
||||
cache_dir = model_path
|
||||
print(f"Using cached model from: {cache_dir}")
|
||||
else:
|
||||
cache_dir = snapshot_download(repo_id="lerobot/pi05_base", repo_type="model")
|
||||
print(f"Downloaded model to: {cache_dir}")
|
||||
|
||||
# Try to load safetensors format first
|
||||
model_file = os.path.join(cache_dir, "model.safetensors")
|
||||
if os.path.exists(model_file):
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
print(f"Loaded {len(state_dict)} parameters from safetensors")
|
||||
else:
|
||||
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
|
||||
|
||||
# Load the state dict into the model
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys: {len(missing_keys)}")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in missing_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"Unexpected keys: {len(unexpected_keys)}")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in unexpected_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("All pretrained weights loaded successfully!")
|
||||
else:
|
||||
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load pretrained weights: {e}")
|
||||
print(" Using randomly initialized weights...")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
policy.to(DEVICE)
|
||||
return policy
|
||||
|
||||
|
||||
def create_dummy_data():
|
||||
batch_size = 2
|
||||
batch_size = 2 # Reduce batch size for testing
|
||||
device = DEVICE
|
||||
|
||||
# Use the exact same prompt for both implementations
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
|
||||
),
|
||||
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
"observation.images.left_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
"observation.images.right_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
def prepare_parity_inputs(lerobot_pi05, lerobot_preprocessor):
|
||||
torch.manual_seed(0)
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
)
|
||||
openpi_actions = openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
)
|
||||
assert_processor_inputs_match_lerobot(
|
||||
lerobot_pi05,
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=False,
|
||||
)
|
||||
batch_size = raw_batch[OBS_STATE].shape[0]
|
||||
noise = torch.randn(
|
||||
batch_size,
|
||||
DUMMY_ACTION_HORIZON,
|
||||
DUMMY_ACTION_DIM,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE,
|
||||
)
|
||||
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
|
||||
return lerobot_batch, openpi_observation, openpi_actions, noise, time
|
||||
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
|
||||
"""Extract the exact same processed inputs that LeRobot uses internally."""
|
||||
# Get the tokenized language from LeRobot's internal method
|
||||
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
|
||||
|
||||
# Get the preprocessed images from LeRobot's internal method
|
||||
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for original implementation
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
|
||||
|
||||
|
||||
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(
|
||||
compile_model=compile_model,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
original_pi05 = instantiate_original_pi05()
|
||||
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
|
||||
lerobot_pi05,
|
||||
lerobot_preprocessor,
|
||||
class PI05Observation:
|
||||
"""Observation class that matches the original OpenPI format."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state,
|
||||
images,
|
||||
image_masks,
|
||||
tokenized_prompt,
|
||||
tokenized_prompt_mask,
|
||||
token_ar_mask,
|
||||
token_loss_mask,
|
||||
):
|
||||
self.state = state
|
||||
self.images = images
|
||||
self.image_masks = image_masks
|
||||
self.tokenized_prompt = tokenized_prompt
|
||||
self.tokenized_prompt_mask = tokenized_prompt_mask
|
||||
self.token_ar_mask = token_ar_mask
|
||||
self.token_loss_mask = token_loss_mask
|
||||
|
||||
|
||||
def create_original_observation_with_openpi_preprocessing(batch):
|
||||
"""Create observation object for OpenPI using OpenPI's own preprocessing with pi05 state tokenizer."""
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
device = batch["observation.state"].device
|
||||
|
||||
# Create tokenizer for OpenPI (same as LeRobot uses)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
# Get task description (pi05 processor handles all text formatting)
|
||||
tasks = batch.get("task", ["Pick up the object"] * batch_size)
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks] * batch_size
|
||||
elif len(tasks) == 1:
|
||||
tasks = tasks * batch_size
|
||||
|
||||
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
|
||||
state = batch["observation.state"]
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
|
||||
state = pad_vector(state, DUMMY_STATE_DIM)
|
||||
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
# Create pi05-formatted prompts that include state information
|
||||
full_prompts = []
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
tokenized = tokenizer(
|
||||
full_prompts,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=DUMMY_MAX_TOKEN_LEN,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if gradient_checkpointing:
|
||||
lerobot_pi05.train()
|
||||
else:
|
||||
lerobot_pi05.eval()
|
||||
original_pi05.eval()
|
||||
lang_tokens = tokenized["input_ids"].to(device)
|
||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
|
||||
with fixed_flow_sampling(lerobot_pi05.model, noise=noise, time=time):
|
||||
lerobot_loss, _ = lerobot_pi05(lerobot_batch, reduction="none")
|
||||
with deterministic_openpi_forward_preprocess(original_pi05):
|
||||
openpi_losses = original_pi05(openpi_observation, openpi_actions, noise=noise, time=time)
|
||||
openpi_loss = openpi_losses.mean(dim=(1, 2))
|
||||
# Create dummy token_ar_mask and token_loss_mask for OpenPI
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
|
||||
image_dict = {
|
||||
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
|
||||
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
}
|
||||
|
||||
# Create image masks (all ones for real images)
|
||||
image_masks_dict = {}
|
||||
for key in image_dict:
|
||||
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||
|
||||
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
|
||||
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(compile_model=compile_model)
|
||||
original_pi05 = instantiate_original_pi05()
|
||||
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
|
||||
lerobot_pi05,
|
||||
lerobot_preprocessor,
|
||||
# Create raw observation object (before preprocessing)
|
||||
raw_observation = PI05Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
lerobot_pi05.eval()
|
||||
original_pi05.eval()
|
||||
# Now use OpenPI's preprocessing
|
||||
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
|
||||
|
||||
return processed_obs
|
||||
|
||||
|
||||
def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
|
||||
_batch_size = batch["observation.state"].shape[0]
|
||||
_device = batch["observation.state"].device
|
||||
|
||||
# Extract the exact same processed inputs that LeRobot uses
|
||||
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
|
||||
extract_lerobot_processed_inputs(lerobot_pi0, batch)
|
||||
)
|
||||
|
||||
# Convert images list to dict with original OpenPI keys
|
||||
image_dict = {
|
||||
"base_0_rgb": images[0],
|
||||
"left_wrist_0_rgb": images[1],
|
||||
"right_wrist_0_rgb": images[2],
|
||||
}
|
||||
|
||||
# Convert image masks list to dict with original OpenPI keys
|
||||
image_masks_dict = {
|
||||
"base_0_rgb": img_masks[0],
|
||||
"left_wrist_0_rgb": img_masks[1],
|
||||
"right_wrist_0_rgb": img_masks[2],
|
||||
}
|
||||
|
||||
return PI05Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
|
||||
def test_pi05_original_vs_lerobot():
|
||||
"""Test PI05 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05(
|
||||
from_pretrained=True
|
||||
) # Load pretrained LeRobot model
|
||||
original_pi0 = instantiate_original_pi05(
|
||||
from_pretrained=True
|
||||
) # Load pretrained OpenPI model from HuggingFace Hub
|
||||
|
||||
print("Creating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
batch_lerobot = deepcopy(batch)
|
||||
|
||||
# Test each model with its own preprocessing (more realistic end-to-end test)
|
||||
print("\nTest each model with its own preprocessing")
|
||||
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||
|
||||
print(f"Task prompt: '{batch['task'][0]}'")
|
||||
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
|
||||
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
|
||||
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
|
||||
|
||||
print("Testing OpenPI with own preprocessing...")
|
||||
original_pi0.eval()
|
||||
torch.manual_seed(42) # Set seed for reproducibility
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
|
||||
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
lerobot_actions = lerobot_pi05.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
|
||||
openpi_actions = original_pi05.sample_actions(
|
||||
device=DEVICE,
|
||||
observation=openpi_observation,
|
||||
noise=noise,
|
||||
num_steps=10,
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
|
||||
)
|
||||
openpi_actions_unit = openpi_actions[:, 0, :]
|
||||
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
|
||||
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
|
||||
|
||||
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
||||
print("Testing LeRobot with own preprocessing...")
|
||||
lerobot_pi05.eval()
|
||||
torch.manual_seed(42) # Set the same seed
|
||||
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
with torch.no_grad():
|
||||
lerobot_actions_own = lerobot_pi05.predict_action_chunk(
|
||||
batch_lerobot_processed
|
||||
) # batch_size, n_action_steps, action_dim
|
||||
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
|
||||
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
|
||||
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
|
||||
|
||||
def test_pi05_forward_matches_openpi():
|
||||
assert_forward_matches()
|
||||
print("\nComparing end-to-end implementations:")
|
||||
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
|
||||
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||
|
||||
|
||||
def test_pi05_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi()
|
||||
|
||||
|
||||
def test_pi05_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(gradient_checkpointing=True)
|
||||
|
||||
|
||||
def test_pi05_compile_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True)
|
||||
|
||||
|
||||
def test_pi05_compile_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi(compile_model=True)
|
||||
|
||||
|
||||
def test_pi05_compile_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True, gradient_checkpointing=True)
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
|
||||
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi0 import PI0Config # noqa: E402
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
|
||||
assert_cache_stability,
|
||||
assert_compiled_output_matches_eager,
|
||||
assert_explain_has_no_graph_breaks,
|
||||
benchmark_runtime,
|
||||
make_compile_config,
|
||||
reset_compile_state,
|
||||
)
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
|
||||
def _make_model(*, compile_model):
|
||||
return PI0Pytorch(make_compile_config(PI0Config, compile_model=compile_model)).cuda().eval()
|
||||
|
||||
|
||||
def _make_dummy_inputs(config):
|
||||
device = torch.device("cuda")
|
||||
common = {
|
||||
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
|
||||
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
|
||||
"lang_tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
|
||||
"lang_masks": torch.ones(1, 5, dtype=torch.bool, device=device),
|
||||
"state": torch.randn(1, config.max_state_dim, device=device),
|
||||
}
|
||||
forward_kwargs = {
|
||||
**common,
|
||||
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"time": torch.rand(1, device=device),
|
||||
}
|
||||
sample_kwargs = {
|
||||
**common,
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"num_steps": config.num_inference_steps,
|
||||
}
|
||||
return forward_kwargs, sample_kwargs
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_torch_compile_forward_and_sample_actions():
|
||||
if not hasattr(torch, "compile"):
|
||||
pytest.skip("torch.compile is not available")
|
||||
if not torch._dynamo.is_dynamo_supported():
|
||||
pytest.skip("torch._dynamo is not supported on this platform")
|
||||
|
||||
torch.manual_seed(0)
|
||||
eager_model = _make_model(compile_model=False)
|
||||
torch.manual_seed(0)
|
||||
compiled_model = _make_model(compile_model=True)
|
||||
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
|
||||
|
||||
try:
|
||||
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
|
||||
|
||||
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi0.forward")
|
||||
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi0.sample_actions")
|
||||
|
||||
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi0.forward")
|
||||
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions")
|
||||
|
||||
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi0.forward")
|
||||
benchmark_runtime(
|
||||
eager_model.sample_actions, compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions"
|
||||
)
|
||||
finally:
|
||||
reset_compile_state()
|
||||
del eager_model
|
||||
del compiled_model
|
||||
torch.cuda.empty_cache()
|
||||
@@ -14,56 +14,51 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compare LeRobot PI0 against the vendored OpenPI PyTorch reference."""
|
||||
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip if openpi or transformers is not available
|
||||
pytest.importorskip("openpi")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.configs import PreTrainedConfig # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
deterministic_openpi_forward_preprocess,
|
||||
fix_reference_state_dict,
|
||||
fixed_flow_sampling,
|
||||
load_openpi_reference_state_dict,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||
|
||||
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
|
||||
from lerobot.types import PolicyAction # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 48
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
COMPILE_MODE = "default"
|
||||
FORWARD_RTOL = 1e-4
|
||||
FORWARD_ATOL = 1e-4
|
||||
SAMPLE_RTOL = 1e-2
|
||||
SAMPLE_ATOL = 5e-3
|
||||
DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05)
|
||||
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
OBS_STATE: {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
ACTION: {
|
||||
"action": {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
@@ -92,15 +87,6 @@ DUMMY_DATASET_STATS = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_cuda_after_test():
|
||||
yield
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
class PI0BaseOriginalConfig:
|
||||
action_dim: int = DUMMY_ACTION_DIM
|
||||
action_horizon: int = DUMMY_ACTION_HORIZON
|
||||
@@ -109,156 +95,333 @@ class PI0BaseOriginalConfig:
|
||||
precision: str = "float32"
|
||||
pi05: bool = False
|
||||
dtype: str = "float32"
|
||||
pytorch_compile_mode: str | None = None
|
||||
|
||||
|
||||
def instantiate_lerobot_pi0(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
config = PreTrainedConfig.from_pretrained("lerobot/pi0_base")
|
||||
config.device = str(DEVICE)
|
||||
config.dtype = "float32"
|
||||
config.compile_model = compile_model
|
||||
config.compile_mode = COMPILE_MODE
|
||||
config.gradient_checkpointing = gradient_checkpointing
|
||||
def instantiate_lerobot_pi0(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI0Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI0Policy.from_pretrained(pretrained_name_or_path="lerobot/pi0_base", strict=True)
|
||||
else:
|
||||
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI0Policy(config)
|
||||
|
||||
policy = PI0Policy.from_pretrained("lerobot/pi0_base", config=config, strict=True)
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = str(DEVICE)
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
return policy, preprocessor
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
|
||||
|
||||
def instantiate_original_pi0():
|
||||
policy = PI0Pytorch(PI0BaseOriginalConfig()).to(DEVICE)
|
||||
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi0_base"))
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
assert missing_keys == []
|
||||
assert unexpected_keys == []
|
||||
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None):
|
||||
config = PI0BaseOriginalConfig()
|
||||
policy = PI0Pytorch(config)
|
||||
|
||||
if from_pretrained:
|
||||
try:
|
||||
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi0_base)...")
|
||||
|
||||
# Download the model from HuggingFace Hub
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the entire repository
|
||||
if model_path and os.path.exists(model_path):
|
||||
cache_dir = model_path
|
||||
print(f"Using cached model from: {cache_dir}")
|
||||
else:
|
||||
cache_dir = snapshot_download(repo_id="lerobot/pi0_base", repo_type="model")
|
||||
print(f"Downloaded model to: {cache_dir}")
|
||||
|
||||
# Try to load safetensors format first
|
||||
model_file = os.path.join(cache_dir, "model.safetensors")
|
||||
if os.path.exists(model_file):
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
print(f"Loaded {len(state_dict)} parameters from safetensors")
|
||||
else:
|
||||
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
|
||||
|
||||
# Load the state dict into the model
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys: {len(missing_keys)}")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in missing_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"Unexpected keys: {len(unexpected_keys)}")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in unexpected_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("All pretrained weights loaded successfully!")
|
||||
else:
|
||||
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load pretrained weights: {e}")
|
||||
print(" Using randomly initialized weights...")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
policy.to(DEVICE)
|
||||
return policy
|
||||
|
||||
|
||||
def create_dummy_data():
|
||||
batch_size = 2
|
||||
batch_size = 2 # Reduce batch size for testing
|
||||
device = DEVICE
|
||||
|
||||
# Use the exact same prompt for both implementations
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
|
||||
),
|
||||
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
"observation.images.left_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
"observation.images.right_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
def prepare_parity_inputs(lerobot_pi0, lerobot_preprocessor):
|
||||
torch.manual_seed(0)
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
)
|
||||
openpi_actions = openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
)
|
||||
assert_processor_inputs_match_lerobot(
|
||||
lerobot_pi0,
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=True,
|
||||
)
|
||||
batch_size = raw_batch[OBS_STATE].shape[0]
|
||||
noise = torch.randn(
|
||||
batch_size,
|
||||
DUMMY_ACTION_HORIZON,
|
||||
DUMMY_ACTION_DIM,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE,
|
||||
)
|
||||
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
|
||||
return lerobot_batch, openpi_observation, openpi_actions, noise, time
|
||||
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
|
||||
"""Extract the exact same processed inputs that LeRobot uses internally."""
|
||||
# Get the tokenized language from LeRobot's internal method
|
||||
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
|
||||
|
||||
# Get the preprocessed images from LeRobot's internal method
|
||||
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for original implementation
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
|
||||
|
||||
|
||||
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(
|
||||
compile_model=compile_model,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
original_pi0 = instantiate_original_pi0()
|
||||
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
|
||||
lerobot_pi0,
|
||||
lerobot_preprocessor,
|
||||
)
|
||||
class PI0Observation:
|
||||
"""Observation class that matches the original OpenPI format."""
|
||||
|
||||
if gradient_checkpointing:
|
||||
lerobot_pi0.train()
|
||||
def __init__(
|
||||
self,
|
||||
state,
|
||||
images,
|
||||
image_masks,
|
||||
tokenized_prompt,
|
||||
tokenized_prompt_mask,
|
||||
token_ar_mask,
|
||||
token_loss_mask,
|
||||
):
|
||||
self.state = state
|
||||
self.images = images
|
||||
self.image_masks = image_masks
|
||||
self.tokenized_prompt = tokenized_prompt
|
||||
self.tokenized_prompt_mask = tokenized_prompt_mask
|
||||
self.token_ar_mask = token_ar_mask
|
||||
self.token_loss_mask = token_loss_mask
|
||||
|
||||
|
||||
def create_original_observation_with_openpi_preprocessing(batch):
|
||||
"""Create observation object for OpenPI using OpenPI's own preprocessing."""
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
device = batch["observation.state"].device
|
||||
|
||||
# Create tokenizer for OpenPI (same as LeRobot uses)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
# Get task description
|
||||
if "task" in batch:
|
||||
tasks = batch["task"]
|
||||
if isinstance(tasks, str):
|
||||
# Single string: add newline if not present, then convert to list
|
||||
if not tasks.endswith("\n"):
|
||||
tasks = f"{tasks}\n"
|
||||
tasks = [tasks]
|
||||
elif isinstance(tasks, list) and all(isinstance(t, str) for t in tasks):
|
||||
# List of strings: add newline to each if not present
|
||||
tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks]
|
||||
if len(tasks) == 1:
|
||||
# Expand to batch size
|
||||
tasks = tasks * batch_size
|
||||
if len(tasks) != batch_size:
|
||||
raise ValueError(f"Expected batch size {batch_size}, got {len(tasks)}")
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
else:
|
||||
lerobot_pi0.eval()
|
||||
original_pi0.eval()
|
||||
# Default task if not provided
|
||||
tasks = ["Pick up the object\n"] * batch_size
|
||||
|
||||
with fixed_flow_sampling(lerobot_pi0.model, noise=noise, time=time):
|
||||
lerobot_loss, _ = lerobot_pi0(lerobot_batch, reduction="none")
|
||||
with deterministic_openpi_forward_preprocess(original_pi0):
|
||||
openpi_losses = original_pi0(openpi_observation, openpi_actions, noise=noise, time=time)
|
||||
openpi_loss = openpi_losses.mean(dim=(1, 2))
|
||||
|
||||
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
|
||||
|
||||
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
|
||||
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(compile_model=compile_model)
|
||||
original_pi0 = instantiate_original_pi0()
|
||||
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
|
||||
lerobot_pi0,
|
||||
lerobot_preprocessor,
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
tokenized = tokenizer(
|
||||
tasks,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=DUMMY_MAX_TOKEN_LEN,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
lerobot_pi0.eval()
|
||||
lang_tokens = tokenized["input_ids"].to(device)
|
||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for OpenPI
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
|
||||
image_dict = {
|
||||
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
|
||||
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
}
|
||||
|
||||
# Create image masks (all ones for real images)
|
||||
image_masks_dict = {}
|
||||
for key in image_dict:
|
||||
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||
|
||||
# Create raw observation object (before preprocessing)
|
||||
raw_observation = PI0Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
# Now use OpenPI's preprocessing
|
||||
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
|
||||
|
||||
return processed_obs
|
||||
|
||||
|
||||
def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
|
||||
_batch_size = batch["observation.state"].shape[0]
|
||||
_device = batch["observation.state"].device
|
||||
|
||||
# Extract the exact same processed inputs that LeRobot uses
|
||||
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
|
||||
extract_lerobot_processed_inputs(lerobot_pi0, batch)
|
||||
)
|
||||
|
||||
# Convert images list to dict with original OpenPI keys
|
||||
image_dict = {
|
||||
"base_0_rgb": images[0],
|
||||
"left_wrist_0_rgb": images[1],
|
||||
"right_wrist_0_rgb": images[2],
|
||||
}
|
||||
|
||||
# Convert image masks list to dict with original OpenPI keys
|
||||
image_masks_dict = {
|
||||
"base_0_rgb": img_masks[0],
|
||||
"left_wrist_0_rgb": img_masks[1],
|
||||
"right_wrist_0_rgb": img_masks[2],
|
||||
}
|
||||
|
||||
return PI0Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
|
||||
def test_pi0_original_vs_lerobot():
|
||||
"""Test PI0 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
lerobot_pi0, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi0(
|
||||
from_pretrained=True
|
||||
) # Load pretrained LeRobot model
|
||||
original_pi0 = instantiate_original_pi0(
|
||||
from_pretrained=True
|
||||
) # Load pretrained OpenPI model from HuggingFace Hub
|
||||
|
||||
print("Creating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
batch_lerobot = deepcopy(batch)
|
||||
|
||||
# Test each model with its own preprocessing (more realistic end-to-end test)
|
||||
print("\nTest each model with its own preprocessing")
|
||||
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||
|
||||
print(f"Task prompt: '{batch['task'][0]}'")
|
||||
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
|
||||
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
|
||||
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
|
||||
|
||||
print("Testing OpenPI with own preprocessing...")
|
||||
original_pi0.eval()
|
||||
torch.manual_seed(42) # Set seed for reproducibility
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
|
||||
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
lerobot_actions = lerobot_pi0.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE,
|
||||
observation=openpi_observation,
|
||||
noise=noise,
|
||||
num_steps=10,
|
||||
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
|
||||
)
|
||||
openpi_actions_unit = openpi_actions[:, 0, :]
|
||||
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
|
||||
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
|
||||
|
||||
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
||||
print("Testing LeRobot with own preprocessing...")
|
||||
lerobot_pi0.eval()
|
||||
torch.manual_seed(42) # Set the same seed
|
||||
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
with torch.no_grad():
|
||||
lerobot_actions_own = lerobot_pi0.predict_action_chunk(
|
||||
batch_lerobot_processed
|
||||
) # batch_size, n_action_steps, action_dim
|
||||
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
|
||||
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
|
||||
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
|
||||
|
||||
def test_pi0_forward_matches_openpi():
|
||||
assert_forward_matches()
|
||||
print("\nComparing end-to-end implementations:")
|
||||
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
|
||||
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||
|
||||
|
||||
def test_pi0_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi()
|
||||
|
||||
|
||||
def test_pi0_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(gradient_checkpointing=True)
|
||||
|
||||
|
||||
def test_pi0_compile_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True)
|
||||
|
||||
|
||||
def test_pi0_compile_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi(compile_model=True)
|
||||
|
||||
|
||||
def test_pi0_compile_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True, gradient_checkpointing=True)
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
|
||||
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Utilities shared by PI0/PI05 policy tests."""
|
||||
@@ -1,291 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as openpi_preprocessing
|
||||
|
||||
IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
|
||||
TOKENIZER_NAME = "google/paligemma-3b-pt-224"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenPIObservation:
|
||||
state: torch.Tensor
|
||||
images: dict[str, torch.Tensor]
|
||||
image_masks: dict[str, torch.Tensor]
|
||||
tokenized_prompt: torch.Tensor
|
||||
tokenized_prompt_mask: torch.Tensor
|
||||
token_ar_mask: torch.Tensor
|
||||
token_loss_mask: torch.Tensor
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def paligemma_tokenizer():
|
||||
return AutoTokenizer.from_pretrained(TOKENIZER_NAME)
|
||||
|
||||
|
||||
def clone_batch(batch: dict) -> dict:
|
||||
return {
|
||||
key: value.clone() if isinstance(value, torch.Tensor) else list(value) for key, value in batch.items()
|
||||
}
|
||||
|
||||
|
||||
def pad_last_dim(tensor: torch.Tensor, target_dim: int) -> torch.Tensor:
|
||||
if tensor.shape[-1] > target_dim:
|
||||
raise ValueError(f"Cannot pad last dimension {tensor.shape[-1]} down to {target_dim}")
|
||||
return F.pad(tensor, (0, target_dim - tensor.shape[-1]))
|
||||
|
||||
|
||||
def mean_std_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
mean = stats["mean"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
std = stats["std"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
return (tensor - mean) / (std + 1e-8)
|
||||
|
||||
|
||||
def quantile_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
q01 = stats["q01"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
q99 = stats["q99"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
denom = torch.where(q99 == q01, torch.full_like(q99, 1e-8), q99 - q01)
|
||||
return 2.0 * (tensor - q01) / denom - 1.0
|
||||
|
||||
|
||||
def openpi_model_state_from_raw(
|
||||
batch: dict[str, torch.Tensor],
|
||||
*,
|
||||
action_dim: int,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
||||
pi05: bool,
|
||||
) -> torch.Tensor:
|
||||
state = batch[OBS_STATE].to(dtype=torch.float32)
|
||||
if pi05:
|
||||
state = quantile_normalize(state, dataset_stats[OBS_STATE])
|
||||
else:
|
||||
state = mean_std_normalize(state, dataset_stats[OBS_STATE])
|
||||
return pad_last_dim(state, action_dim)
|
||||
|
||||
|
||||
def openpi_model_actions_from_raw(
|
||||
batch: dict[str, torch.Tensor],
|
||||
*,
|
||||
action_dim: int,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
||||
pi05: bool,
|
||||
) -> torch.Tensor:
|
||||
actions = batch[ACTION].to(dtype=torch.float32)
|
||||
if pi05:
|
||||
actions = quantile_normalize(actions, dataset_stats[ACTION])
|
||||
else:
|
||||
actions = mean_std_normalize(actions, dataset_stats[ACTION])
|
||||
return pad_last_dim(actions, action_dim)
|
||||
|
||||
|
||||
def _tasks_from_raw(batch: dict, batch_size: int) -> list[str]:
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
raise ValueError("The parity batch must include a task prompt.")
|
||||
if isinstance(tasks, str):
|
||||
return [tasks] * batch_size
|
||||
if len(tasks) == 1:
|
||||
return [tasks[0]] * batch_size
|
||||
if len(tasks) != batch_size:
|
||||
raise ValueError(f"Expected {batch_size} task prompts, got {len(tasks)}")
|
||||
return list(tasks)
|
||||
|
||||
|
||||
def _format_pi0_prompts(tasks: list[str]) -> list[str]:
|
||||
return [f"{task.strip().replace('_', ' ').replace(chr(10), ' ')}\n" for task in tasks]
|
||||
|
||||
|
||||
def _format_pi05_prompts(tasks: list[str], normalized_state: torch.Tensor) -> list[str]:
|
||||
state_np = normalized_state.detach().cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
prompts = []
|
||||
for task, state in zip(tasks, discretized_states, strict=True):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, state))
|
||||
prompts.append(f"Task: {cleaned_text}, State: {state_str};\nAction: ")
|
||||
return prompts
|
||||
|
||||
|
||||
def _tokenize_prompts(prompts: list[str], *, max_token_len: int, device: torch.device | str):
|
||||
tokenized = paligemma_tokenizer()(
|
||||
prompts,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=max_token_len,
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = tokenized["input_ids"].to(device)
|
||||
masks = tokenized["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
return tokens, masks
|
||||
|
||||
|
||||
def make_openpi_observation_from_raw(
|
||||
batch: dict[str, torch.Tensor],
|
||||
*,
|
||||
action_dim: int,
|
||||
max_token_len: int,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
||||
pi05: bool,
|
||||
) -> OpenPIObservation:
|
||||
batch_size = batch[OBS_STATE].shape[0]
|
||||
device = batch[OBS_STATE].device
|
||||
state = openpi_model_state_from_raw(
|
||||
batch,
|
||||
action_dim=action_dim,
|
||||
dataset_stats=dataset_stats,
|
||||
pi05=pi05,
|
||||
)
|
||||
|
||||
tasks = _tasks_from_raw(batch, batch_size)
|
||||
prompts = _format_pi05_prompts(tasks, state) if pi05 else _format_pi0_prompts(tasks)
|
||||
tokens, masks = _tokenize_prompts(prompts, max_token_len=max_token_len, device=device)
|
||||
|
||||
images = {
|
||||
key: batch[f"observation.images.{key}"].to(device=device, dtype=torch.float32) * 2.0 - 1.0
|
||||
for key in IMAGE_KEYS
|
||||
}
|
||||
image_masks = {key: torch.ones(batch_size, dtype=torch.bool, device=device) for key in IMAGE_KEYS}
|
||||
|
||||
return OpenPIObservation(
|
||||
state=state,
|
||||
images=images,
|
||||
image_masks=image_masks,
|
||||
tokenized_prompt=tokens,
|
||||
tokenized_prompt_mask=masks,
|
||||
token_ar_mask=torch.zeros_like(tokens, dtype=torch.int32),
|
||||
token_loss_mask=torch.ones_like(masks, dtype=torch.bool),
|
||||
)
|
||||
|
||||
|
||||
def assert_processor_inputs_match_lerobot(
|
||||
lerobot_policy,
|
||||
lerobot_batch: dict[str, torch.Tensor],
|
||||
openpi_observation: OpenPIObservation,
|
||||
*,
|
||||
compare_state: bool,
|
||||
):
|
||||
openpi_processed = openpi_preprocessing.preprocess_observation_pytorch(openpi_observation, train=False)
|
||||
lerobot_images, lerobot_image_masks = lerobot_policy._preprocess_images(lerobot_batch)
|
||||
|
||||
# Token IDs, token masks, images, image masks, and PI0 state are intentionally built from the same
|
||||
# raw batch through independent LeRobot/OpenPI-style processor logic. They must be bitwise equal.
|
||||
torch.testing.assert_close(
|
||||
openpi_observation.tokenized_prompt, lerobot_batch[OBS_LANGUAGE_TOKENS], rtol=0, atol=0
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
openpi_observation.tokenized_prompt_mask,
|
||||
lerobot_batch[OBS_LANGUAGE_ATTENTION_MASK],
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
|
||||
for openpi_image, lerobot_image in zip(openpi_processed.images.values(), lerobot_images, strict=True):
|
||||
torch.testing.assert_close(openpi_image, lerobot_image, rtol=0, atol=0)
|
||||
|
||||
for openpi_mask, lerobot_mask in zip(
|
||||
openpi_processed.image_masks.values(), lerobot_image_masks, strict=True
|
||||
):
|
||||
torch.testing.assert_close(openpi_mask, lerobot_mask, rtol=0, atol=0)
|
||||
|
||||
if compare_state:
|
||||
torch.testing.assert_close(
|
||||
openpi_processed.state, lerobot_policy.prepare_state(lerobot_batch), rtol=0, atol=0
|
||||
)
|
||||
|
||||
|
||||
def load_openpi_reference_state_dict(repo_id: str) -> dict[str, torch.Tensor]:
|
||||
cache_dir = Path(snapshot_download(repo_id=repo_id, repo_type="model"))
|
||||
return safetensors.torch.load_file(cache_dir / "model.safetensors")
|
||||
|
||||
|
||||
def fix_reference_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
fixed_state_dict = dict(state_dict)
|
||||
lm_head_key = "paligemma_with_expert.paligemma.lm_head.weight"
|
||||
embed_tokens_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
if lm_head_key in fixed_state_dict and embed_tokens_key not in fixed_state_dict:
|
||||
fixed_state_dict[embed_tokens_key] = fixed_state_dict[lm_head_key].clone()
|
||||
return fixed_state_dict
|
||||
|
||||
|
||||
@contextmanager
|
||||
def fixed_flow_sampling(model, *, noise: torch.Tensor, time: torch.Tensor) -> Iterator[None]:
|
||||
original_sample_noise = model.sample_noise
|
||||
original_sample_time = model.sample_time
|
||||
|
||||
def sample_noise(shape, device):
|
||||
if tuple(shape) != tuple(noise.shape):
|
||||
raise ValueError(f"Expected noise shape {tuple(noise.shape)}, got {tuple(shape)}")
|
||||
return noise.to(device=device)
|
||||
|
||||
def sample_time(batch_size, device):
|
||||
if batch_size != time.shape[0]:
|
||||
raise ValueError(f"Expected time batch size {time.shape[0]}, got {batch_size}")
|
||||
return time.to(device=device)
|
||||
|
||||
model.sample_noise = sample_noise
|
||||
model.sample_time = sample_time
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
model.sample_noise = original_sample_noise
|
||||
model.sample_time = original_sample_time
|
||||
|
||||
|
||||
@contextmanager
|
||||
def deterministic_openpi_forward_preprocess(openpi_policy) -> Iterator[None]:
|
||||
"""Disable OpenPI's training-time image augmentation only inside a parity forward block.
|
||||
|
||||
OpenPI's `forward()` calls `_preprocess_observation(..., train=True)`, which can apply stochastic
|
||||
image augmentation. LeRobot's policy forward path does not apply that augmentation, so parity would
|
||||
otherwise compare two different image tensors rather than two model implementations. The context manager
|
||||
keeps the public `openpi_policy.forward(observation, ...)` call while making preprocessing deterministic.
|
||||
|
||||
`yield` marks the body of the caller's `with` block. The `try/finally` restores the original method even
|
||||
if the assertion inside the block fails, so the temporary monkeypatch cannot leak into later tests.
|
||||
"""
|
||||
|
||||
original_preprocess_observation = openpi_policy._preprocess_observation
|
||||
|
||||
def preprocess_observation(observation, *, train=True):
|
||||
return original_preprocess_observation(observation, train=False)
|
||||
|
||||
openpi_policy._preprocess_observation = preprocess_observation
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
openpi_policy._preprocess_observation = original_preprocess_observation
|
||||
@@ -1,207 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters, guard_failures
|
||||
from torch.profiler import ProfilerActivity
|
||||
|
||||
FORWARD_RTOL = 1e-5
|
||||
FORWARD_ATOL = 5e-2
|
||||
SAMPLE_RTOL = 1e-5
|
||||
SAMPLE_ATOL = 1e-2
|
||||
COMPILE_MODE = "max-autotune"
|
||||
STEADY_STATE_WARMUPS = 3
|
||||
STEADY_STATE_REPEATS = 3
|
||||
|
||||
|
||||
def make_compile_config(config_cls, *, compile_model):
|
||||
return config_cls(device="cuda", compile_model=compile_model, compile_mode=COMPILE_MODE)
|
||||
|
||||
|
||||
def counter_total(name):
|
||||
return sum(counters.get(name, {}).values())
|
||||
|
||||
|
||||
def compile_snapshot():
|
||||
return {
|
||||
"graph_breaks": counter_total("graph_break"),
|
||||
"recompiles": counter_total("recompiles"),
|
||||
"recompile_limits": counter_total("recompile_limit"),
|
||||
"unique_graphs": counters["stats"].get("unique_graphs", 0),
|
||||
}
|
||||
|
||||
|
||||
def reset_compile_state():
|
||||
torch._dynamo.reset()
|
||||
counters.clear()
|
||||
guard_failures.clear()
|
||||
|
||||
|
||||
def clone_cuda_graph_output(output):
|
||||
if torch.is_tensor(output):
|
||||
return output.clone()
|
||||
if isinstance(output, tuple):
|
||||
return tuple(clone_cuda_graph_output(item) for item in output)
|
||||
if isinstance(output, list):
|
||||
return [clone_cuda_graph_output(item) for item in output]
|
||||
if isinstance(output, dict):
|
||||
return {key: clone_cuda_graph_output(value) for key, value in output.items()}
|
||||
return output
|
||||
|
||||
|
||||
def run_model_step(fn: Callable, kwargs: dict):
|
||||
if hasattr(torch.compiler, "cudagraph_mark_step_begin"):
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
return fn(**kwargs)
|
||||
|
||||
|
||||
def assert_explain_has_no_graph_breaks(fn: Callable, kwargs: dict, label: str):
|
||||
reset_compile_state()
|
||||
explanation = torch._dynamo.explain(fn)(**kwargs)
|
||||
|
||||
assert explanation.graph_count > 0, f"{label} was not captured by Dynamo"
|
||||
assert explanation.graph_break_count == 0, (
|
||||
f"{label} has {explanation.graph_break_count} graph break(s): {explanation.break_reasons}"
|
||||
)
|
||||
assert not explanation.break_reasons, f"{label} graph break reasons: {explanation.break_reasons}"
|
||||
|
||||
print(
|
||||
f"{label} capture: graphs={explanation.graph_count}, "
|
||||
f"graph_breaks={explanation.graph_break_count}, ops={explanation.op_count}, "
|
||||
f"guards={len(explanation.out_guards or [])}"
|
||||
)
|
||||
return explanation
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs):
|
||||
eager_forward = eager_model.forward(**forward_kwargs)
|
||||
compiled_forward = compiled_model.forward(**forward_kwargs)
|
||||
torch.testing.assert_close(compiled_forward, eager_forward, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
|
||||
eager_actions = eager_model.sample_actions(**sample_kwargs)
|
||||
compiled_actions = compiled_model.sample_actions(**sample_kwargs)
|
||||
torch.testing.assert_close(compiled_actions, eager_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def assert_cache_stability(fn: Callable, kwargs: dict, label: str):
|
||||
reset_compile_state()
|
||||
|
||||
first_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
||||
first_snapshot = compile_snapshot()
|
||||
second_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
||||
second_snapshot = compile_snapshot()
|
||||
third_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
||||
third_snapshot = compile_snapshot()
|
||||
|
||||
torch.testing.assert_close(second_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
torch.testing.assert_close(third_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
assert first_snapshot["unique_graphs"] > 0, f"{label} did not compile any graph"
|
||||
assert third_snapshot["graph_breaks"] == 0, f"{label} graph breaks: {third_snapshot}"
|
||||
assert third_snapshot["recompiles"] == 0, f"{label} recompiled: {third_snapshot}"
|
||||
assert third_snapshot["recompile_limits"] == 0, f"{label} hit recompile limit: {third_snapshot}"
|
||||
assert second_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
|
||||
f"{label} compiled new graph on second call: first={first_snapshot}, second={second_snapshot}"
|
||||
)
|
||||
assert third_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
|
||||
f"{label} compiled new graph on third call: first={first_snapshot}, third={third_snapshot}"
|
||||
)
|
||||
assert not guard_failures, f"{label} guard failures: {dict(guard_failures)}"
|
||||
|
||||
print(f"{label} cache: first={first_snapshot}, third={third_snapshot}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def benchmark_runtime(eager_fn: Callable, compiled_fn: Callable, kwargs: dict, label: str):
|
||||
run_warmups(eager_fn, kwargs)
|
||||
run_warmups(compiled_fn, kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
eager_metrics = profile_callable(eager_fn, kwargs)
|
||||
compiled_metrics = profile_callable(compiled_fn, kwargs)
|
||||
speedup = eager_metrics["cuda_event_ms"] / compiled_metrics["cuda_event_ms"]
|
||||
|
||||
print(
|
||||
f"{label} runtime: eager_cuda={eager_metrics['cuda_event_ms']:.3f} ms, "
|
||||
f"compiled_cuda={compiled_metrics['cuda_event_ms']:.3f} ms, speedup={speedup:.3f}x, "
|
||||
f"host_wall_ms eager/compiled={eager_metrics['host_wall_ms']:.3f}/"
|
||||
f"{compiled_metrics['host_wall_ms']:.3f}, "
|
||||
f"cpu_self_time_ms eager/compiled={eager_metrics['cpu_self_time_ms']:.3f}/"
|
||||
f"{compiled_metrics['cpu_self_time_ms']:.3f}, "
|
||||
f"cuda_launches eager/compiled={eager_metrics['cuda_launch_count']}/"
|
||||
f"{compiled_metrics['cuda_launch_count']}, "
|
||||
f"profiler_events eager/compiled={eager_metrics['profiler_event_count']}/"
|
||||
f"{compiled_metrics['profiler_event_count']}, "
|
||||
f"peak_mem_mib eager/compiled={eager_metrics['peak_mem_mib']:.1f}/"
|
||||
f"{compiled_metrics['peak_mem_mib']:.1f}"
|
||||
)
|
||||
|
||||
assert eager_metrics["cuda_event_ms"] > 0
|
||||
assert compiled_metrics["cuda_event_ms"] > 0
|
||||
assert eager_metrics["profiler_event_count"] > 0
|
||||
assert compiled_metrics["profiler_event_count"] > 0
|
||||
return eager_metrics, compiled_metrics
|
||||
|
||||
|
||||
def run_warmups(fn: Callable, kwargs: dict):
|
||||
for _ in range(STEADY_STATE_WARMUPS):
|
||||
run_model_step(fn, kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def profile_callable(fn: Callable, kwargs: dict):
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
host_start = time.perf_counter()
|
||||
start_event.record()
|
||||
for _ in range(STEADY_STATE_REPEATS):
|
||||
run_model_step(fn, kwargs)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
cuda_event_ms = start_event.elapsed_time(end_event) / STEADY_STATE_REPEATS
|
||||
host_wall_ms = (time.perf_counter() - host_start) * 1000 / STEADY_STATE_REPEATS
|
||||
peak_mem_mib = torch.cuda.max_memory_allocated() / 1024**2
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[ProfilerActivity.CPU],
|
||||
) as profiler:
|
||||
run_model_step(fn, kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
key_averages = profiler.key_averages()
|
||||
cpu_self_time_ms = sum(event.self_cpu_time_total for event in key_averages) / 1000
|
||||
cuda_launch_count = sum(
|
||||
event.count
|
||||
for event in key_averages
|
||||
if event.key in {"cudaLaunchKernel", "cudaGraphLaunch", "cudaLaunchKernelExC"}
|
||||
)
|
||||
profiler_event_count = sum(event.count for event in key_averages)
|
||||
|
||||
return {
|
||||
"cuda_event_ms": cuda_event_ms,
|
||||
"host_wall_ms": host_wall_ms,
|
||||
"cpu_self_time_ms": cpu_self_time_ms,
|
||||
"cuda_launch_count": cuda_launch_count,
|
||||
"profiler_event_count": profiler_event_count,
|
||||
"peak_mem_mib": peak_mem_mib,
|
||||
}
|
||||
268
tests/policies/vla_jepa/conftest.py
Normal file
268
tests/policies/vla_jepa/conftest.py
Normal file
@@ -0,0 +1,268 @@
|
||||
#!/usr/bin/env python
|
||||
"""Shared fixtures and helpers for VLA-JEPA tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
BATCH_SIZE = 2
|
||||
ACTION_DIM = 3
|
||||
STATE_DIM = 4
|
||||
IMAGE_SIZE = 8
|
||||
ACTION_HORIZON = 4
|
||||
N_ACTION_STEPS = 2
|
||||
NUM_VIDEO_FRAMES = 3
|
||||
QWEN_HIDDEN_SIZE = 16 # hidden size produced by _FakeQwenBackbone
|
||||
|
||||
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def set_seed_all(seed: int) -> None:
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def make_config(
|
||||
action_dim: int = ACTION_DIM,
|
||||
state_dim: int = STATE_DIM,
|
||||
action_horizon: int = ACTION_HORIZON,
|
||||
num_video_frames: int = NUM_VIDEO_FRAMES,
|
||||
) -> VLAJEPAConfig:
|
||||
config = VLAJEPAConfig(
|
||||
input_features={
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||
},
|
||||
output_features={
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
|
||||
},
|
||||
device="cpu",
|
||||
chunk_size=action_horizon,
|
||||
n_action_steps=min(N_ACTION_STEPS, action_horizon),
|
||||
action_dim=action_dim,
|
||||
state_dim=state_dim,
|
||||
num_video_frames=num_video_frames,
|
||||
num_action_tokens_per_timestep=2,
|
||||
num_embodied_action_tokens_per_instruction=3,
|
||||
num_inference_timesteps=2,
|
||||
action_hidden_size=QWEN_HIDDEN_SIZE,
|
||||
action_model_type="DiT-test",
|
||||
action_num_layers=1,
|
||||
predictor_depth=1,
|
||||
predictor_num_heads=2,
|
||||
predictor_mlp_ratio=2.0,
|
||||
jepa_tubelet_size=1,
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def make_train_batch(
|
||||
batch_size: int = BATCH_SIZE,
|
||||
action_dim: int = ACTION_DIM,
|
||||
state_dim: int = STATE_DIM,
|
||||
action_horizon: int = ACTION_HORIZON,
|
||||
num_video_frames: int = NUM_VIDEO_FRAMES,
|
||||
) -> dict[str, Tensor | list[str]]:
|
||||
return {
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, num_video_frames, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||
OBS_STATE: torch.randn(batch_size, 1, state_dim),
|
||||
ACTION: torch.randn(batch_size, action_horizon, action_dim),
|
||||
"task": ["pick up the cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
def make_inference_batch(
|
||||
batch_size: int = BATCH_SIZE,
|
||||
state_dim: int = STATE_DIM,
|
||||
) -> dict[str, Tensor | list[str]]:
|
||||
return {
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
"task": ["pick up the cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake external models (replace Qwen3-VL and V-JEPA at test time)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeLanguageLayer(nn.Module):
|
||||
"""Leaf module whose forward hook is captured by _qwen_last_decoder_hidden."""
|
||||
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self._hidden_size = hidden_size
|
||||
|
||||
def forward(self, hidden: Tensor, **_: object) -> tuple[Tensor, ...]:
|
||||
return (hidden,)
|
||||
|
||||
|
||||
class _FakeLanguageModel(nn.Module):
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self._hidden_size = hidden_size
|
||||
self.layers = nn.ModuleList([_FakeLanguageLayer(hidden_size)])
|
||||
|
||||
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden = torch.zeros(batch_size, seq_len, self._hidden_size, device=input_ids.device)
|
||||
self.layers[-1](hidden)
|
||||
return SimpleNamespace()
|
||||
|
||||
|
||||
class _FakeQwenInnerModel(nn.Module):
|
||||
"""Mimics the `.model.model` level that _qwen_last_decoder_hidden walks into."""
|
||||
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.language_model = _FakeLanguageModel(hidden_size)
|
||||
|
||||
def forward(self, input_ids: Tensor, **kwargs: object) -> SimpleNamespace:
|
||||
return self.language_model(input_ids)
|
||||
|
||||
|
||||
class _FakeQwenBackbone(nn.Module):
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(1))
|
||||
self.config = SimpleNamespace(
|
||||
hidden_size=hidden_size,
|
||||
text_config=SimpleNamespace(hidden_size=hidden_size),
|
||||
)
|
||||
self.model = _FakeQwenInnerModel(hidden_size)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.weight.device
|
||||
|
||||
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_size = self.config.hidden_size
|
||||
values = torch.arange(
|
||||
batch_size * seq_len * hidden_size,
|
||||
device=input_ids.device,
|
||||
dtype=torch.float32,
|
||||
).view(batch_size, seq_len, hidden_size)
|
||||
hidden = values / values.numel() + self.weight
|
||||
return SimpleNamespace(hidden_states=[hidden])
|
||||
|
||||
|
||||
class _FakeQwenInterface(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = _FakeQwenBackbone(hidden_size=QWEN_HIDDEN_SIZE)
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||
return torch.float32 if dtype_name == "float32" else torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
|
||||
action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
|
||||
action_token_ids = list(range(1000, 1000 + max_action_tokens))
|
||||
return action_tokens, action_token_ids, 2000
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: list[list[Image.Image]],
|
||||
instructions: list[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
) -> dict[str, Tensor]:
|
||||
batch_size = len(images)
|
||||
del images, instructions, action_prompt, embodied_prompt
|
||||
action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
|
||||
token_ids = (
|
||||
[10]
|
||||
+ list(range(1000, 1000 + action_count))
|
||||
+ [2000] * self.config.num_embodied_action_tokens_per_instruction
|
||||
+ [11]
|
||||
)
|
||||
return {
|
||||
"input_ids": torch.tensor(
|
||||
[token_ids] * batch_size,
|
||||
device=self.model.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
|
||||
return Image.fromarray(image)
|
||||
|
||||
|
||||
class _FakeVideoEncoder(nn.Module):
|
||||
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(1))
|
||||
# image_size must be >= patch_size (16) so the predictor grid is non-zero.
|
||||
# Setting image_size=16 gives a 1x1 grid (1 patch per frame).
|
||||
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size, image_size=16)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.weight.device
|
||||
|
||||
def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
|
||||
batch_size, num_frames = pixel_values_videos.shape[:2]
|
||||
hidden_size = self.config.hidden_size
|
||||
frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
|
||||
return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
|
||||
|
||||
|
||||
class _FakeVideoProcessor:
|
||||
def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
|
||||
assert return_tensors == "pt"
|
||||
return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from lerobot.policies.vla_jepa import modeling_vla_jepa
|
||||
|
||||
monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
|
||||
monkeypatch.setattr(
|
||||
modeling_vla_jepa.AutoModel,
|
||||
"from_pretrained",
|
||||
lambda *args, **kwargs: _FakeVideoEncoder(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
modeling_vla_jepa.AutoVideoProcessor,
|
||||
"from_pretrained",
|
||||
lambda *args, **kwargs: _FakeVideoProcessor(),
|
||||
)
|
||||
157
tests/policies/vla_jepa/test_action_head.py
Normal file
157
tests/policies/vla_jepa/test_action_head.py
Normal file
@@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("diffusers")
|
||||
|
||||
from conftest import (
|
||||
ACTION_DIM,
|
||||
ACTION_HORIZON,
|
||||
BATCH_SIZE,
|
||||
QWEN_HIDDEN_SIZE,
|
||||
STATE_DIM,
|
||||
make_config,
|
||||
set_seed_all,
|
||||
) # noqa: E402
|
||||
|
||||
from lerobot.policies.vla_jepa.action_head import ( # noqa: E402
|
||||
VLAJEPAActionHead,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VLAJEPAActionHead
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4), # default test dims
|
||||
(7, 0, 16), # no proprioceptive state, production-like action space
|
||||
(6, 8, 8), # medium dims
|
||||
],
|
||||
)
|
||||
def test_action_head_sample_time_range(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
t = head.sample_time(batch_size=200, device=torch.device("cpu"), dtype=torch.float32)
|
||||
assert t.shape == (200,)
|
||||
assert torch.isfinite(t).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_action_head_build_inputs_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(2, action_horizon, action_dim)
|
||||
timesteps = torch.randint(0, 100, (2,))
|
||||
|
||||
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||
out_with = head._build_inputs(conditioning, actions, state, timesteps)
|
||||
out_none = head._build_inputs(conditioning, actions, None, timesteps)
|
||||
|
||||
assert out_with.ndim == 3 and out_none.ndim == 3
|
||||
if state_dim > 0:
|
||||
assert out_with.shape[1] > out_none.shape[1]
|
||||
assert torch.isfinite(out_with).all() and torch.isfinite(out_none).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_action_head_forward_loss_valid(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(2, action_horizon, action_dim)
|
||||
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||
loss = head.forward(conditioning, actions, state)
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss) and loss > 0
|
||||
|
||||
|
||||
def test_action_head_forward_gradient_flows() -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
loss = head.forward(conditioning, actions, state)
|
||||
loss.backward()
|
||||
assert any(p.grad is not None for p in head.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_action_head_predict_action_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||
pred = head.predict_action(conditioning, state)
|
||||
assert tuple(pred.shape) == (2, action_horizon, action_dim)
|
||||
assert torch.isfinite(pred).all()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# action_is_pad masking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_action_head_loss_fully_padded_is_zero() -> None:
|
||||
"""Loss is 0 when every timestep is padded (exercises the clamp_min guard)."""
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
|
||||
action_is_pad = torch.ones(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
|
||||
loss = head.forward(conditioning, actions, state, action_is_pad)
|
||||
assert loss.item() == 0.0
|
||||
|
||||
|
||||
def test_action_head_loss_none_matches_no_padding() -> None:
|
||||
"""action_is_pad=None is equivalent to an all-False (no padding) mask."""
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
|
||||
set_seed_all(0)
|
||||
loss_none = head.forward(conditioning, actions, state, action_is_pad=None)
|
||||
|
||||
set_seed_all(0)
|
||||
no_pad = torch.zeros(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
|
||||
loss_zeros = head.forward(conditioning, actions, state, action_is_pad=no_pad)
|
||||
|
||||
assert torch.isclose(loss_none, loss_zeros)
|
||||
57
tests/policies/vla_jepa/test_configuration.py
Normal file
57
tests/policies/vla_jepa/test_configuration.py
Normal file
@@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
def test_delta_indices() -> None:
|
||||
config = make_config()
|
||||
assert config.observation_delta_indices == list(range(NUM_VIDEO_FRAMES))
|
||||
assert config.action_delta_indices == list(range(ACTION_HORIZON))
|
||||
|
||||
|
||||
def test_n_action_steps_exceeds_chunk_size_raises() -> None:
|
||||
with pytest.raises(ValueError, match="n_action_steps"):
|
||||
VLAJEPAConfig(chunk_size=4, n_action_steps=8)
|
||||
|
||||
|
||||
def test_too_few_video_frames_raises() -> None:
|
||||
with pytest.raises(ValueError, match="video_horizon"):
|
||||
VLAJEPAConfig(
|
||||
chunk_size=16,
|
||||
n_action_steps=16,
|
||||
num_video_frames=2,
|
||||
jepa_tubelet_size=2, # needs >= 4 frames (2 for current, 2 for future) to have a window of size > 0
|
||||
)
|
||||
|
||||
|
||||
def test_validate_features_no_image_raises() -> None:
|
||||
config = VLAJEPAConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
|
||||
)
|
||||
with pytest.raises(ValueError, match="at least one visual input feature"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_no_action_raises() -> None:
|
||||
config = VLAJEPAConfig(
|
||||
input_features={
|
||||
f"{OBS_IMAGES}.cam": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
|
||||
},
|
||||
output_features={},
|
||||
)
|
||||
with pytest.raises(ValueError, match="action output feature"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_sets_action_dim_from_feature() -> None:
|
||||
config = make_config(action_dim=6, state_dim=10)
|
||||
assert config.action_dim == 6
|
||||
assert config.state_dim == 10
|
||||
473
tests/policies/vla_jepa/test_vla_jepa.py
Normal file
473
tests/policies/vla_jepa/test_vla_jepa.py
Normal file
@@ -0,0 +1,473 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
pytest.importorskip("diffusers")
|
||||
|
||||
pytestmark = pytest.mark.filterwarnings(
|
||||
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
|
||||
)
|
||||
|
||||
from conftest import ( # noqa: E402
|
||||
ACTION_DIM,
|
||||
ACTION_HORIZON,
|
||||
BATCH_SIZE,
|
||||
EXPECTED_ACTION_CHUNK_SHAPE,
|
||||
EXPECTED_SELECT_ACTION_SHAPE,
|
||||
N_ACTION_STEPS,
|
||||
STATE_DIM,
|
||||
make_config,
|
||||
make_inference_batch,
|
||||
make_train_batch,
|
||||
set_seed_all,
|
||||
)
|
||||
|
||||
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402
|
||||
from lerobot.utils.constants import ACTION # noqa: E402
|
||||
|
||||
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
|
||||
PRETRAINED_SUBFOLDER = "LIBERO"
|
||||
|
||||
# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are
|
||||
# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in.
|
||||
_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0"
|
||||
extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core training / inference tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.train()
|
||||
|
||||
batch = make_train_batch()
|
||||
batch_before = deepcopy(batch)
|
||||
|
||||
loss, logs = policy.forward(batch)
|
||||
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss)
|
||||
assert set(logs) == {"action_loss", "wm_loss", "loss"}
|
||||
assert logs["action_loss"] > 0
|
||||
assert logs["wm_loss"] >= 0
|
||||
|
||||
loss.backward()
|
||||
assert any(p.grad is not None for p in policy.model.action_model.parameters() if p.requires_grad)
|
||||
# Batch must not be mutated.
|
||||
assert set(batch) == set(batch_before)
|
||||
for key, value in batch.items():
|
||||
if isinstance(value, Tensor):
|
||||
assert torch.equal(value, batch_before[key])
|
||||
else:
|
||||
assert value == batch_before[key]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 4])
|
||||
def test_training_forward_various_batch_sizes(patch_vla_jepa_external_models: None, batch_size: int) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.train()
|
||||
loss, logs = policy.forward(make_train_batch(batch_size=batch_size))
|
||||
assert torch.isfinite(loss) and loss > 0
|
||||
assert set(logs) == {"action_loss", "wm_loss", "loss"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_training_forward_various_dims(
|
||||
patch_vla_jepa_external_models: None,
|
||||
action_dim: int,
|
||||
state_dim: int,
|
||||
action_horizon: int,
|
||||
) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
policy = VLAJEPAPolicy(config)
|
||||
policy.train()
|
||||
batch = make_train_batch(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
loss, _ = policy.forward(batch)
|
||||
assert torch.isfinite(loss) and loss > 0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_action_generation_shape(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
chunk = policy.predict_action_chunk(batch)
|
||||
assert tuple(chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
||||
assert chunk.device.type == "cpu"
|
||||
assert torch.isfinite(chunk).all()
|
||||
|
||||
a1 = policy.select_action(batch)
|
||||
a2 = policy.select_action(batch)
|
||||
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert torch.isfinite(a1).all() and torch.isfinite(a2).all()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize("action_dim,state_dim", [(3, 4), (7, 0), (6, 8)])
|
||||
def test_action_generation_various_dims(
|
||||
patch_vla_jepa_external_models: None, action_dim: int, state_dim: int
|
||||
) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim)
|
||||
policy = VLAJEPAPolicy(config)
|
||||
policy.eval()
|
||||
batch = make_inference_batch(state_dim=state_dim)
|
||||
chunk = policy.predict_action_chunk(batch)
|
||||
assert chunk.shape[-1] == action_dim
|
||||
assert torch.isfinite(chunk).all()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_inference_reproducibility(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
set_seed_all(123)
|
||||
actions_1 = policy.predict_action_chunk(batch)
|
||||
set_seed_all(123)
|
||||
actions_2 = policy.predict_action_chunk(batch)
|
||||
|
||||
assert tuple(actions_1.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
||||
assert torch.allclose(actions_1, actions_2, atol=1e-6)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_predict_action_chunk_always_finite(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
for seed in [0, 42, 123]:
|
||||
set_seed_all(seed)
|
||||
chunk = policy.predict_action_chunk(make_inference_batch())
|
||||
assert torch.isfinite(chunk).all(), f"non-finite actions with seed={seed}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Action queue behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_select_action_queue_drains_before_refill(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
# First call fills the queue (n_action_steps items) and pops one.
|
||||
a1 = policy.select_action(batch)
|
||||
assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 1
|
||||
|
||||
# Second call pops from the existing queue without calling predict_action_chunk.
|
||||
a2 = policy.select_action(batch)
|
||||
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
policy.select_action(make_inference_batch())
|
||||
assert len(policy._queues[ACTION]) > 0
|
||||
|
||||
policy.reset()
|
||||
assert len(policy._queues[ACTION]) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Format conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
|
||||
from PIL import Image
|
||||
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
examples = policy._prepare_model_inputs(make_train_batch())
|
||||
|
||||
assert len(examples) == BATCH_SIZE
|
||||
for ex in examples:
|
||||
assert set(ex) >= {"image", "video", "lang", "action", "state"}
|
||||
assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image)
|
||||
assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C]
|
||||
assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
|
||||
assert ex["state"].shape == (1, STATE_DIM)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
for ex in policy._prepare_model_inputs(make_inference_batch()):
|
||||
assert "action" not in ex
|
||||
assert "image" in ex and "video" in ex and "lang" in ex
|
||||
|
||||
|
||||
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
del batch["task"]
|
||||
examples = policy._prepare_model_inputs(batch)
|
||||
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
batch["task"] = "open the drawer"
|
||||
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
|
||||
|
||||
|
||||
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
del batch[OBS_STATE]
|
||||
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pretrained checkpoint
|
||||
# Hub tests (opt-in: VLA_JEPA_EXTENDED=1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||
"""Build a training batch whose keys/shapes match a hub-loaded policy config."""
|
||||
cfg = policy.config
|
||||
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||
for key, feat in cfg.image_features.items():
|
||||
h, w = feat.shape[-2], feat.shape[-1]
|
||||
batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w)
|
||||
if cfg.robot_state_feature is not None:
|
||||
batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0])
|
||||
batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim)
|
||||
return batch
|
||||
|
||||
|
||||
def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||
"""Build an inference batch whose keys/shapes match a hub-loaded policy config."""
|
||||
cfg = policy.config
|
||||
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||
for key, feat in cfg.image_features.items():
|
||||
h, w = feat.shape[-2], feat.shape[-1]
|
||||
batch[key] = torch.rand(batch_size, 3, h, w)
|
||||
if cfg.robot_state_feature is not None:
|
||||
batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0])
|
||||
return batch
|
||||
|
||||
|
||||
_CP_ROOT = "lerobot" # TODO: upload converted checkpoints
|
||||
|
||||
# Each tuple: (repo_id, enable_world_model)
|
||||
_HUB_VARIANTS = [
|
||||
(f"{_CP_ROOT}/VLA-JEPA-LIBERO", True),
|
||||
(f"{_CP_ROOT}/VLA-JEPA-Pretrain", True),
|
||||
(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False),
|
||||
]
|
||||
|
||||
|
||||
@extended_test
|
||||
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||
def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None:
|
||||
"""Policy loads from the converted safetensors checkpoint on the Hub."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||
assert policy.config.enable_world_model == enable_world_model
|
||||
assert sum(p.numel() for p in policy.parameters()) > 0
|
||||
|
||||
|
||||
@extended_test
|
||||
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||
def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None:
|
||||
"""Policy loaded from hub produces finite losses with a correctly-shaped batch."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||
policy.train()
|
||||
|
||||
batch = _make_hub_train_batch(policy)
|
||||
loss, logs = policy.forward(batch)
|
||||
assert torch.isfinite(loss)
|
||||
assert "action_loss" in logs
|
||||
if enable_world_model:
|
||||
assert "wm_loss" in logs
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_freeze_qwen_disables_world_model() -> None:
|
||||
"""freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"])
|
||||
assert not policy.config.enable_world_model
|
||||
assert policy.model.video_predictor is None
|
||||
qwen_params = list(policy.model.qwen.parameters())
|
||||
assert all(not p.requires_grad for p in qwen_params)
|
||||
assert any(p.requires_grad for p in policy.model.action_model.parameters())
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_disable_world_model_loads_simpler_env() -> None:
|
||||
"""SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv")
|
||||
assert not policy.config.enable_world_model
|
||||
assert policy.model.video_predictor is None
|
||||
assert policy.model.video_encoder is None
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_libero_inference_shape() -> None:
|
||||
"""select_action returns the expected shape using the LIBERO hub checkpoint."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO")
|
||||
policy.eval()
|
||||
batch = _make_hub_inference_batch(policy)
|
||||
action = policy.select_action(batch)
|
||||
assert action.shape[-1] == policy.config.action_dim
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Postprocessor unnormalization tests
|
||||
#
|
||||
# These tests verify that the postprocessor pipeline (clip → unnorm → binarize)
|
||||
# correctly applies MIN_MAX unnormalization after predict_action_chunk.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_dataset_stats(action_dim: int = ACTION_DIM) -> dict:
|
||||
"""Returns sample dataset_stats with a simple [i, i+10] range per action dim."""
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
return {
|
||||
ACTION: {
|
||||
"min": torch.tensor([float(i) for i in range(action_dim)], dtype=torch.float32),
|
||||
"max": torch.tensor([float(i) + 10.0 for i in range(action_dim)], dtype=torch.float32),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None) -> None:
|
||||
"""UnnormalizerProcessorStep with MIN_MAX produces the correct inverse of MIN_MAX normalization."""
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor import UnnormalizerProcessorStep
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
|
||||
rng = np.random.default_rng(7)
|
||||
actions_np = rng.uniform(-1.0, 1.0, (2, ACTION_HORIZON, ACTION_DIM)).astype(np.float32)
|
||||
|
||||
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||
expected = (actions_np + 1.0) / 2.0 * (a_max - a_min) + a_min
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
|
||||
unnorm_step = UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
|
||||
stats=dataset_stats,
|
||||
)
|
||||
|
||||
actions_tensor = torch.from_numpy(actions_np)
|
||||
transition = policy_action_to_transition(actions_tensor)
|
||||
result = transition_to_policy_action(unnorm_step(transition)).numpy()
|
||||
|
||||
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None:
|
||||
"""ClipActionsProcessorStep clamps to [-1, 1] before unnormalization."""
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
|
||||
from lerobot.processor import UnnormalizerProcessorStep
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||
|
||||
# Deliberately out-of-range inputs
|
||||
actions_np = np.array([[[2.0] * ACTION_DIM, [-3.0] * ACTION_DIM]], dtype=np.float32)
|
||||
clipped = np.clip(actions_np, -1.0, 1.0)
|
||||
expected = (clipped + 1.0) / 2.0 * (a_max - a_min) + a_min
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
|
||||
clip_step = ClipActionsProcessorStep()
|
||||
unnorm_step = UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
|
||||
stats=dataset_stats,
|
||||
)
|
||||
|
||||
transition = policy_action_to_transition(torch.from_numpy(actions_np))
|
||||
transition = clip_step(transition)
|
||||
result = transition_to_policy_action(unnorm_step(transition)).numpy()
|
||||
|
||||
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_postprocessor_applied_after_predict_action_chunk(
|
||||
patch_vla_jepa_external_models: None, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""predict_action_chunk returns raw actions; the postprocessor applies unnormalization.
|
||||
|
||||
Verifies the split: predict_action_chunk returns normalized actions, and calling the
|
||||
postprocessor on them produces the correctly unnormalized result.
|
||||
"""
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
|
||||
|
||||
cfg = make_config()
|
||||
cfg.clip_normalized_actions = False
|
||||
cfg.binarize_gripper_action = False
|
||||
policy = VLAJEPAPolicy(cfg)
|
||||
policy.eval()
|
||||
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
_, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
|
||||
|
||||
batch = make_inference_batch()
|
||||
chunk = policy.predict_action_chunk(batch)
|
||||
|
||||
# predict_action_chunk returns raw (normalized) actions
|
||||
assert torch.allclose(chunk, torch.zeros_like(chunk), atol=1e-6), (
|
||||
"predict_action_chunk should return raw actions without unnormalization applied."
|
||||
)
|
||||
|
||||
# Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i
|
||||
unnormed = postprocessor(chunk)
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||
expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0]
|
||||
assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5)
|
||||
60
tests/policies/vla_jepa/test_world_model.py
Normal file
60
tests/policies/vla_jepa/test_world_model.py
Normal file
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.vla_jepa.world_model import (
|
||||
ActionConditionedVideoPredictor,
|
||||
)
|
||||
|
||||
_ACTION_EMBED_DIM = 8
|
||||
|
||||
|
||||
def _make_predictor(
|
||||
embed_dim: int = 8,
|
||||
action_embed_dim: int = _ACTION_EMBED_DIM,
|
||||
predictor_embed_dim: int = 24,
|
||||
num_action_tokens: int = 2,
|
||||
tokens_per_frame: int = 1,
|
||||
) -> ActionConditionedVideoPredictor:
|
||||
return ActionConditionedVideoPredictor(
|
||||
num_frames=1,
|
||||
img_size=(1, tokens_per_frame),
|
||||
patch_size=1,
|
||||
tubelet_size=1,
|
||||
embed_dim=embed_dim,
|
||||
action_embed_dim=action_embed_dim,
|
||||
predictor_embed_dim=predictor_embed_dim,
|
||||
depth=1,
|
||||
num_heads=2,
|
||||
mlp_ratio=2.0,
|
||||
num_action_tokens_per_step=num_action_tokens,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch,num_steps,tokens_per_frame,embed_dim",
|
||||
[
|
||||
(1, 2, 1, 8),
|
||||
(2, 3, 4, 8),
|
||||
(4, 5, 2, 16),
|
||||
],
|
||||
)
|
||||
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
|
||||
predictor = _make_predictor(
|
||||
embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame
|
||||
)
|
||||
frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim)
|
||||
action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM)
|
||||
out = predictor(frame_tokens, action_tokens)
|
||||
assert tuple(out.shape) == (batch, num_steps * tokens_per_frame, embed_dim)
|
||||
assert torch.isfinite(out).all()
|
||||
|
||||
|
||||
def test_predictor_step_mismatch_raises() -> None:
|
||||
predictor = _make_predictor(tokens_per_frame=4)
|
||||
frame_tokens = torch.randn(2, 3 * 4, 8) # 3 steps, 4 tokens each
|
||||
with pytest.raises(RuntimeError):
|
||||
predictor(frame_tokens, torch.randn(2, 2 * 2, 8)) # 2 steps → mismatch
|
||||
@@ -1,155 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compare the PI0.5 processor pipeline against the vendored OpenPI reference processors."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
|
||||
from lerobot.policies.pi05 import PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
IMAGE_KEYS,
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
|
||||
)
|
||||
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 200
|
||||
DEVICE = torch.device("cpu")
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"q99": torch.ones(DUMMY_ACTION_DIM),
|
||||
},
|
||||
"images": {
|
||||
key: {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
}
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PI05PolicyInputAdapter(torch.nn.Module):
|
||||
"""Minimal adapter exposing PI0.5 policy image preparation without loading model weights."""
|
||||
|
||||
_preprocess_images = PI05Policy._preprocess_images
|
||||
|
||||
def __init__(self, config: PI05Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
|
||||
|
||||
|
||||
def create_pi05_config() -> PI05Config:
|
||||
config = PI05Config(device=str(DEVICE))
|
||||
config.max_state_dim = DUMMY_STATE_DIM
|
||||
config.max_action_dim = DUMMY_ACTION_DIM
|
||||
config.chunk_size = DUMMY_ACTION_HORIZON
|
||||
config.n_action_steps = DUMMY_ACTION_HORIZON
|
||||
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
|
||||
**{
|
||||
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def create_dummy_data() -> dict:
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
**{
|
||||
f"observation.images.{key}": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
)
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
|
||||
|
||||
def test_pi05_processor_inputs_match_openpi_reference():
|
||||
torch.manual_seed(0)
|
||||
config = create_pi05_config()
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
)
|
||||
|
||||
assert_processor_inputs_match_lerobot(
|
||||
PI05PolicyInputAdapter(config),
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=False,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
lerobot_batch[ACTION],
|
||||
openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
),
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
@@ -1,156 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compare the PI0 processor pipeline against the vendored OpenPI reference processors."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
IMAGE_KEYS,
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
|
||||
)
|
||||
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 48
|
||||
DEVICE = torch.device("cpu")
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"q99": torch.ones(DUMMY_ACTION_DIM),
|
||||
},
|
||||
"images": {
|
||||
key: {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
}
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PI0PolicyInputAdapter(torch.nn.Module):
|
||||
"""Minimal adapter exposing PI0 policy input-preparation helpers without loading model weights."""
|
||||
|
||||
_preprocess_images = PI0Policy._preprocess_images
|
||||
prepare_state = PI0Policy.prepare_state
|
||||
|
||||
def __init__(self, config: PI0Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
|
||||
|
||||
|
||||
def create_pi0_config() -> PI0Config:
|
||||
config = PI0Config(device=str(DEVICE))
|
||||
config.max_state_dim = DUMMY_STATE_DIM
|
||||
config.max_action_dim = DUMMY_ACTION_DIM
|
||||
config.chunk_size = DUMMY_ACTION_HORIZON
|
||||
config.n_action_steps = DUMMY_ACTION_HORIZON
|
||||
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
|
||||
**{
|
||||
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def create_dummy_data() -> dict:
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
**{
|
||||
f"observation.images.{key}": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
)
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
|
||||
|
||||
def test_pi0_processor_inputs_match_openpi_reference():
|
||||
torch.manual_seed(0)
|
||||
config = create_pi0_config()
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
)
|
||||
|
||||
assert_processor_inputs_match_lerobot(
|
||||
PI0PolicyInputAdapter(config),
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=True,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
lerobot_batch[ACTION],
|
||||
openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
),
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
33
uv.lock
generated
33
uv.lock
generated
@@ -3039,6 +3039,11 @@ video-benchmark = [
|
||||
viz = [
|
||||
{ name = "rerun-sdk" },
|
||||
]
|
||||
vla-jepa = [
|
||||
{ name = "diffusers" },
|
||||
{ name = "qwen-vl-utils" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
wallx = [
|
||||
{ name = "peft" },
|
||||
{ name = "qwen-vl-utils" },
|
||||
@@ -3107,6 +3112,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
|
||||
@@ -3154,6 +3160,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'unitree-g1'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["rebot"], marker = "extra == 'all'" },
|
||||
@@ -3177,12 +3184,14 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'pi'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'smolvla'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'xvla'" },
|
||||
{ name = "lerobot", extras = ["video-benchmark"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'core-scripts'" },
|
||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'dataset-viz'" },
|
||||
{ name = "lerobot", extras = ["vla-jepa"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["wallx"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["xvla"], marker = "extra == 'all'" },
|
||||
{ name = "matplotlib", marker = "extra == 'matplotlib-dep'", specifier = ">=3.10.3,<4.0.0" },
|
||||
@@ -3203,7 +3212,7 @@ requires-dist = [
|
||||
{ name = "pandas", marker = "extra == 'video-benchmark'", specifier = ">=2.2.2,<2.4.0" },
|
||||
{ name = "peft", marker = "extra == 'peft-dep'", specifier = ">=0.18.0,<1.0.0" },
|
||||
{ name = "pillow", specifier = ">=10.0.0,<13.0.0" },
|
||||
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.16" },
|
||||
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.17" },
|
||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" },
|
||||
{ name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" },
|
||||
{ name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" },
|
||||
@@ -3244,7 +3253,7 @@ requires-dist = [
|
||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
|
||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
||||
]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
|
||||
[[package]]
|
||||
name = "librt"
|
||||
@@ -4592,7 +4601,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "placo"
|
||||
version = "0.9.15"
|
||||
version = "0.9.16"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cmeel" },
|
||||
@@ -4602,16 +4611,16 @@ dependencies = [
|
||||
{ name = "pin" },
|
||||
{ name = "rhoban-cmeel-jsoncpp" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/40/c4/a33a0ee2ad798471a1c43a96109d28f358fd95c78a56f8cff57acb66d2bc/placo-0.9.15.tar.gz", hash = "sha256:df47f1154bae305c943bd20ba4f56d50ffc65625efc98679fefb11e8ff3c462c", size = 136856, upload-time = "2025-11-03T10:49:13.151Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9e/0a/36c5b729d0d69075e7dfafd1b36c4df6fbb8c1ff1585e88d3c56d4c15010/placo-0.9.16.tar.gz", hash = "sha256:5314faaf6442e7ffe17347680d236af953951813bbfb1c09c4a27f7388d332e4", size = 136871, upload-time = "2025-11-07T14:24:58.811Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ef/03/207b1c087996b918fdbaa5a3a685e3b14b068cd303bf87affdf83f722b33/placo-0.9.15-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:eab7a299e73291fe631c02448b9e9826539f4824e198bcf85f7c91fdd77d054b", size = 1641975, upload-time = "2025-11-03T10:48:48.887Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/55/40432b26bb1c5b9e677fbc41e8d85b54fa8897b7daebb2a22d410b0a7f7b/placo-0.9.15-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:23f9dd19b8d15fa9d86968948b57981ebc6f1decafeffc2d646d8b56f685b50d", size = 1515448, upload-time = "2025-11-03T10:48:50.562Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/8e/e6283201d329409dccf2045b5c1efd73b3dad5268143bbea4668029ca9c6/placo-0.9.15-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:2680a2166c23a0a2aa6226ad75c63a2b2310c812673a5db296616d9af053e076", size = 2106550, upload-time = "2025-11-03T10:48:52.364Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/c3/77efe4c999e1d80ec14879ef73ea2a2144aa12db2b67870a562f87ed5b43/placo-0.9.15-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:1a2202a78bcd2874ca09a9a6526a95b38874803923cb9b3b4b96cd68ab4b7217", size = 2178531, upload-time = "2025-11-03T10:48:53.932Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fe/e7/b5cc5ad53414ff7af3357e0c9d97d902a3ce276e7810f8814fe9f0c1fb70/placo-0.9.15-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:84a445a99b059a512d1b4c64841a91d6f50149c7be9255c65bedeebbe6663989", size = 1641982, upload-time = "2025-11-03T10:48:55.277Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/1c/1c9163d941698a077617f218041efc573d3bf5a1c169a284112bd622fccd/placo-0.9.15-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b3106e7e6b05cbfa494239d8aa14795f7da8ee5dec851602f0d6297e311d7334", size = 1515447, upload-time = "2025-11-03T10:48:56.975Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/22/3d9b9045b89248c8476dd42243bc9821a123d9199e4e96a944124ad80cf1/placo-0.9.15-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:66c3d099e87551401aace04f1293a3c3563b1399319976647846845bf92c3ccf", size = 2106558, upload-time = "2025-11-03T10:48:58.667Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/20/0b/45dbdd2c378a7cece578b7344fda493d5a2aa6777089798a315ce4f97c22/placo-0.9.15-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0e06b7d3d618ddc2b649ab8b0b46db8001fe72fe2fbcc801524df0ccc8a3da40", size = 2178531, upload-time = "2025-11-03T10:49:00.533Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/95/8a85b58033303fd354a680e1494f47801abdca9133c222ae1c2473983f25/placo-0.9.16-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:417a89920b340e3aec19f1f49e1fb06789c679a807450157af8bdf4aef4bc82b", size = 1641806, upload-time = "2025-11-07T14:24:34.736Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/bd/2fb3556c71b0689b3168c0e85fce5befb605affcfe4afb3b5e7b5ba6749f/placo-0.9.16-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a7ef7ac33ba889d2122db0d7ed55eeecdffed020e2282712989bb11e408bab40", size = 1515468, upload-time = "2025-11-07T14:24:36.587Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ea/fd/7dba380720dfb89df582a51d0b2cb43957a36849f676baa3dfc74704e67f/placo-0.9.16-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:885773fe8a8e809022451ec16d47479562a042596f663b8c5bbe762cd616f573", size = 2106540, upload-time = "2025-11-07T14:24:38.149Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7a/40/97c7c799fe4f89111b973d7a5f86626a2ec1d0e6e20ce2988e0a2bda66f5/placo-0.9.16-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:19f097305c714e539fbf19e761897f6daab2ff73f639319431b144e77dd3852e", size = 2178511, upload-time = "2025-11-07T14:24:40.04Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/4d/f1700aae269584477b5d72561d2fc5ace37b4bca167892a74a369849c67e/placo-0.9.16-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:be11fa987702114097ccf3d94e1c4a891796878429e25c8d88b187ecc652e7ae", size = 1641812, upload-time = "2025-11-07T14:24:41.308Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/d7/21d1d0dd1311c0cbd9ccd233cdae520bbe2370095e3c831059d6077c90bd/placo-0.9.16-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c2d65aeb4844eae28006ad3a50c8519b27c701912cc99c46c95e33ed049f3635", size = 1515457, upload-time = "2025-11-07T14:24:42.758Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/e8/939ba23bfa539fb90ab9ab1c2c59ff9a9a46e24699fc90e8ca3ff2948646/placo-0.9.16-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a7633aff1c592c1f45e86a174a372d5d7972673935cb9151391277ff49ec2072", size = 2106538, upload-time = "2025-11-07T14:24:44.517Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/08/00/ad24cc0ad85fbe12267df28c2061e1eaef8f852146c467fcd7a681e11028/placo-0.9.16-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0d97a7284b65fc45aef27865c80cf7e53f04646d35bb18494ab62dfbbc9a35bd", size = 2178514, upload-time = "2025-11-07T14:24:45.994Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user