Compare commits

7 Commits

Author SHA1 Message Date
Khalil Meftah
badbb69fb7 fix(rewards/topreward): Remove TOPReward README symlink, use docs page only 2026-05-27 18:50:14 +01:00
Khalil Meftah
23036baf22 fix(rewards/topreward): dd README symlink for TOPReward docs 2026-05-27 18:33:00 +01:00
Haoming Song
3b5b94dbd6 optmize topreward input processing (#3660) 2026-05-25 16:07:45 +02:00
Cole
616663cd9f fix(rewards/topreward): fix pyproject extra typo and simplify processor (#3653)
Add lerobot[topreward] extra to all in
pyproject.toml, drop the redundant labels arg in scoring, and
collapse the dead-branch shape check in the encoder processor.
2026-05-23 00:27:09 +02:00
Khalil Meftah
5cfca59ec7 fix(rewards/topreward): add missing input keys mm_token_type_ids 2026-05-21 11:05:02 +02:00
Khalil Meftah
f6ecb7b955 refactor(rewards): clean up TOPReward processor/model 2026-05-20 17:39:21 +02:00
Khalil Meftah
70ad322676 feat(rewards): add TOPReward reward model 2026-05-19 18:00:18 +02:00
70 changed files with 898 additions and 19608 deletions

View File

@@ -59,8 +59,6 @@
title: π₀-FAST (Pi0Fast)
- local: pi05
title: π₀.₅ (Pi05)
- local: molmoact2
title: MolmoAct2
- local: eo1
title: EO-1
- local: groot

View File

@@ -79,13 +79,17 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/act_policy \
--robot.type=so101_follower \
lerobot-record \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=my_robot \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--display_data=true \
--task="Your task description" \ # can be skipped for ACT
--duration=60
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
--dataset.num_episodes=10 \
--dataset.single_task="Your task description" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
--policy.path=${HF_USER}/act_policy
```

View File

@@ -105,12 +105,10 @@ These results demonstrate GR00T's strong generalization capabilities across dive
### Evaluate in your hardware setup
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
```bash
lerobot-rollout\
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
lerobot-record \
--robot.type=bi_so_follower \
--robot.left_arm_port=/dev/ttyACM1 \
--robot.right_arm_port=/dev/ttyACM0 \
@@ -121,12 +119,14 @@ lerobot-rollout\
}' \
--display_data=true \
--dataset.repo_id=<user>/eval_groot-bimanual \
--dataset.num_episodes=10 \
--dataset.single_task="Grab and handover the red cube to the other arm" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
--policy.path=<user>/groot-bimanual \ # your trained model
--duration=600
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10
```
## License

View File

@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
port="/dev/tty.usbmodem58760431541",
id="my_red_robot_arm",
)
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
port="/dev/tty.usbmodem58760431551",
id="my_blue_leader_arm",
)
robot = SO101Follower(robot_config)
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
<hfoption id="Command">
```bash
lerobot-teleoperate \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem5AB90687491 \
--robot.id=my_follower_arm \
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so101_leader \
--teleop.port=/dev/tty.usbmodem5AB90689011 \
--teleop.id=my_leader_arm \
--robot.type=koch_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=my_awesome_follower_arm \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
--teleop.type=koch_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=my_awesome_leader_arm \
--display_data=true
```
</hfoption>
@@ -122,48 +122,34 @@ lerobot-teleoperate \
<!-- prettier-ignore-start -->
```python
import time
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
cameras={
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
}
camera_config = {
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
}
robot_config = KochFollowerConfig(
port="/dev/tty.usbmodem585A0076841",
id="my_red_robot_arm",
cameras=camera_config
)
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
teleop_config = KochLeaderConfig(
port="/dev/tty.usbmodem58760431551",
id="my_blue_leader_arm",
)
init_rerun(session_name="teleoperation")
robot = SO101Follower(robot_config)
teleop_device = SO101Leader(teleop_config)
robot = KochFollower(robot_config)
teleop_device = KochLeader(teleop_config)
robot.connect()
teleop_device.connect()
TARGET_HZ = 30
TIME_PER_FRAME = 1.0 / TARGET_HZ
while True:
start_time = time.perf_counter()
observation = robot.get_observation()
action = teleop_device.get_action()
robot.send_action(action)
log_rerun_data(observation=observation, action=action)
elapsed_time = time.perf_counter() - start_time
sleep_time = TIME_PER_FRAME - elapsed_time
if sleep_time > 0:
time.sleep(sleep_time)
```
<!-- prettier-ignore-end -->
@@ -216,11 +202,10 @@ lerobot-record \
<!-- prettier-ignore-start -->
```python
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets import LeRobotDataset
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
@@ -233,56 +218,71 @@ EPISODE_TIME_SEC = 60
RESET_TIME_SEC = 10
TASK_DESCRIPTION = "My task description"
def main():
# Create robot configuration
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
cameras={
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
}
)
# Create robot configuration
robot_config = SO100FollowerConfig(
id="my_awesome_follower_arm",
cameras={
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
},
port="/dev/tty.usbmodem58760434471",
)
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
)
teleop_config = SO100LeaderConfig(
id="my_awesome_leader_arm",
port="/dev/tty.usbmodem585A0077581",
)
# Initialize the robot and teleoperator
robot = SO101Follower(robot_config)
teleop = SO101Leader(teleop_config)
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
teleop = SO100Leader(teleop_config)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id="<hf_username>/<dataset_repo_id>",
# Create the dataset
dataset = LeRobotDataset.create(
repo_id="<hf_username>/<dataset_repo_id>",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
teleop.connect()
# Create the required processors
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
teleop.connect()
# Create the required processors
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
@@ -291,50 +291,26 @@ def main():
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
episode_idx += 1
dataset.save_episode()
episode_idx += 1
# finalize dataset
log_say("Finalizing dataset...")
dataset.finalize()
# Clean up
log_say("Stop recording")
robot.disconnect()
teleop.disconnect()
dataset.push_to_hub()
if __name__ == "__main__":
main()
# Clean up
log_say("Stop recording")
robot.disconnect()
teleop.disconnect()
dataset.push_to_hub()
```
<!-- prettier-ignore-end -->
@@ -372,7 +348,7 @@ The `record` function provides a suite of tools for capturing and managing data
##### 2. Checkpointing and Resuming
- Checkpoints are automatically created during recording.
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, not to the targeted total number of episodes in the dataset. Also set `--dataset.root="local_path"`: this local path is where the new part of the dataset is saved, and it is required to resume.
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, not to the targeted total number of episodes in the dataset.
- To start recording from scratch, **manually delete** the dataset directory.
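As a minimal sketch of the arithmetic behind `--dataset.num_episodes` when resuming: the flag takes the number of episodes still missing, not the final dataset size. The helper name `episodes_to_request` is hypothetical and is not part of LeRobot.

```python
def episodes_to_request(target_total: int, already_recorded: int) -> int:
    """Return the value to pass as --dataset.num_episodes when resuming.

    Hypothetical helper for illustration only; it is not a LeRobot API.
    """
    if target_total < 0 or already_recorded < 0:
        raise ValueError("episode counts must be non-negative")
    # When resuming, the flag counts *additional* episodes, not the dataset total.
    return max(target_total - already_recorded, 0)


# Example: 50 episodes wanted in total, 32 already saved before the interruption.
print(episodes_to_request(50, 32))  # 18
```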
##### 3. Recording Parameters
@@ -446,7 +422,7 @@ from lerobot.utils.utils import log_say
episode_idx = 0
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
robot = SO100Follower(robot_config)
robot.connect()
@@ -514,83 +490,6 @@ Additionally you can provide extra `tags` or specify a `license` for your model
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
#### Train using Hugging Face Jobs
Hugging Face Jobs lets you select hardware and run the training in the cloud. If you don't have a powerful GPU, need more VRAM, or just want to train a model faster, use HF Jobs. It is pay-as-you-go, billed per second of use; see the [Jobs documentation](https://huggingface.co/docs/hub/jobs) for pricing and additional information.
To run the training use this command:
<hfoptions id="train_with_hf_jobs">
<hfoption id="Command">
```bash
hf jobs run \
--flavor a10g-small \
--timeout 4h \
--secrets HF_TOKEN \
huggingface/lerobot-gpu:latest \
-- \
python -m lerobot.scripts.lerobot_train \
--dataset.repo_id=username/dataset \
--policy.type=act \
--steps=5000 \
--batch_size=16 \
--policy.device=cuda \
--policy.repo_id=username/your_policy \
--log_freq=100
```
</hfoption>
<hfoption id="API example">
<!-- prettier-ignore-start -->
```python
from huggingface_hub import run_job, get_token
run_name = "act_so101_hf_jobs"
dataset_id = "username/dataset"
user_hub_id = "username"
command_args = [
"python", "-m", "lerobot.scripts.lerobot_train",
"--dataset.repo_id", dataset_id,
"--policy.type", "act",
"--steps", "5000",
"--batch_size", "16",
"--num_workers", "4",
"--policy.device", "cuda",
"--log_freq", "100",
"--save_freq", "1000",
"--save_checkpoint", "true",
"--wandb.enable", "false",
"--policy.repo_id", f"{user_hub_id}/{run_name}"
]
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
job_info = run_job(
image="huggingface/lerobot-gpu:latest",
command=command_args,
flavor="a10g-small",
timeout="4h",
secrets={"HF_TOKEN": get_token()}
)
print("\n🚀 Job successfully launched!")
print(f"🔹 Job ID: {job_info.id}")
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
```
<!-- prettier-ignore-end -->
</hfoption>
</hfoptions>
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
Depending on the model you want to train and the hardware you selected, you can also modify `--batch_size` and `--num_workers`.
For longer training sessions increase the timeout.
Once training has started, you can go to [Jobs](https://huggingface.co/settings/jobs) to check whether your job is running and to see all of its outputs. Scheduling a job sometimes takes a few minutes, so be patient.
After training, the model will be pushed to the Hub and you can use it like any other model with LeRobot.
#### Upload policy checkpoints
Once training is done, upload the latest checkpoint with:

View File

@@ -1,433 +0,0 @@
# MolmoAct2 Policy
MolmoAct2 is the LeRobot policy implementation of
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into the LeRobot
training, evaluation, checkpointing, and dataset interfaces for easier use with
LeRobot datasets.
This implementation currently supports training and evaluation for the regular
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
not included in this LeRobot policy yet and is coming soon.
For the original MolmoAct2 training code used for the experiments reported in
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
## Installation Requirements
Install LeRobot with the MolmoAct2 optional dependencies:
```bash
pip install -e ".[molmoact2]"
```
To run the models in this repository, you need an NVIDIA GPU. The measurements
below were taken on a single NVIDIA H100 80GB with bf16 model loading, on LIBERO
with two RGB cameras. MolmoAct2 rows use `chunk_size=10`, action dim 7
padded to `expected_max_action_dim=32`, and `num_flow_timesteps=8`. Training measurements use
`gradient_checkpointing=true` and include the forward pass, backward pass,
gradient clipping, optimizer step, and optimizer state allocation. Values are
peak GPU memory sampled with `nvidia-smi`. Leave a few GiB of headroom for
dataloader workers, CUDA context, and fragmentation.
Multi-GPU training through `accelerate` increases throughput and global batch
size, but this LeRobot port does not currently expose the original MolmoAct2
`fsdp_devices` model-parallel training path. The current training script has
not been tested for multi-node training.
| Mode | Peak Memory, bs=8 | Peak Memory, bs=16 | Peak Memory, bs=32 |
| ------------------------------------------------ | ----------------: | -----------------: | -----------------: |
| Inference, continuous, CUDA graph enabled (bs=1) | 12.1 GiB | - | - |
| Fine-tuning, action expert only, continuous | 16.5 GiB | 18.3 GiB | 21.4 GiB |
| Fine-tuning, LoRA VLM, both action modes | 20.2 GiB | 26.8 GiB | 41.3 GiB |
| Fine-tuning, full model, both action modes | 48.3 GiB | 49.8 GiB | 60.1 GiB |
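As a rough sketch of how the table can guide batch-size selection, the snippet below picks the largest measured batch size whose peak memory fits a given GPU after reserving headroom. The peak values are copied from the table; the 4 GiB default headroom and the helper name are assumptions, and the numbers were measured on an H100, so other GPUs may behave differently.

```python
from typing import Dict, Optional

# Peak GPU memory in GiB, copied from the table above (training rows only).
PEAK_GIB: Dict[str, Dict[int, float]] = {
    "action_expert_only": {8: 16.5, 16: 18.3, 32: 21.4},
    "lora_vlm": {8: 20.2, 16: 26.8, 32: 41.3},
    "full": {8: 48.3, 16: 49.8, 32: 60.1},
}


def largest_batch_size(mode: str, gpu_gib: float, headroom_gib: float = 4.0) -> Optional[int]:
    """Return the largest measured batch size that fits, or None if none fits.

    headroom_gib is an assumed reserve for dataloader workers, CUDA context,
    and fragmentation; tune it for your own setup.
    """
    budget = gpu_gib - headroom_gib
    fitting = [bs for bs, peak in PEAK_GIB[mode].items() if peak <= budget]
    return max(fitting) if fitting else None


print(largest_batch_size("lora_vlm", 40.0))  # 16
print(largest_batch_size("full", 48.0))      # None
print(largest_batch_size("full", 80.0))      # 32
```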
The repo has been tested with Ubuntu 22.04.
## Usage
To use MolmoAct2 in a LeRobot training config, set:
```bash
--policy.type=molmoact2
```
## Training
MolmoAct2 can be fine-tuned from either the released MolmoAct2 Hugging Face
checkpoint format or from a checkpoint already saved by LeRobot. Both routes use
the same LeRobot training loop, dataset transforms, checkpoint saving, and
logging. The difference is only how the initial policy weights and processor
state are loaded.
### Training With Original MolmoAct2 Weight
Use `policy.checkpoint_path` when starting from a released MolmoAct2 checkpoint,
for example `allenai/MolmoAct2` or `allenai/MolmoAct2-LIBERO`. LeRobot will load
the original HF model files, then build its own policy processor from the
dataset metadata and the policy options below.
The command below shows full fine-tuning on the merged LIBERO dataset. It uses
bf16 model loading, 8 flow timesteps, LeRobot dataset statistics, image
augmentation, and LeRobot's checkpointing/logging path.
```bash
accelerate launch \
--num_processes=8 \
--mixed_precision=bf16 \
-m lerobot.scripts.lerobot_train \
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
--dataset.video_backend=pyav \
--dataset.image_transforms.enable=true \
--policy.type=molmoact2 \
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
--policy.device=cuda \
--policy.action_mode=both \
--policy.chunk_size=10 \
--policy.n_action_steps=10 \
--policy.setup_type="single franka robotic arm in libero" \
--policy.control_mode="delta end-effector pose" \
--policy.image_keys='["observation.images.image","observation.images.wrist_image"]' \
--policy.model_dtype=bfloat16 \
--policy.num_flow_timesteps=8 \
--policy.gradient_checkpointing=true \
--policy.freeze_embedding=true \
--policy.normalize_gripper=false \
--policy.enable_knowledge_insulation=false \
--policy.push_to_hub=false \
--wandb.enable=true \
--wandb.entity=<wandb_entity> \
--wandb.project=<wandb_project> \
--job_name=<job_name> \
--output_dir=outputs/<job_name> \
--steps=10000 \
--batch_size=32 \
--num_workers=4 \
--log_freq=20 \
--eval_freq=-1 \
--save_checkpoint=true \
--save_freq=2000
```
### Training With LeRobot MolmoAct2 Weight
Use `policy.path` when starting from a MolmoAct2 checkpoint that was saved by
LeRobot, either from a local `pretrained_model` directory or from the Hub. This
restores the saved LeRobot policy config, model weights, processor, and
normalization statistics. You can still override training-time options such as
`batch_size`, `steps`, LoRA flags, or `policy.action_mode`.
```bash
accelerate launch \
--num_processes=8 \
--mixed_precision=bf16 \
-m lerobot.scripts.lerobot_train \
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
--dataset.video_backend=pyav \
--dataset.image_transforms.enable=true \
--policy.path=/path/to/pretrained_model \
--policy.device=cuda \
--policy.action_mode=both \
--policy.chunk_size=10 \
--policy.n_action_steps=10 \
--policy.model_dtype=bfloat16 \
--policy.num_flow_timesteps=8 \
--policy.gradient_checkpointing=true \
--wandb.enable=true \
--wandb.entity=<wandb_entity> \
--wandb.project=<wandb_project> \
--job_name=<job_name> \
--output_dir=outputs/<job_name> \
--steps=10000 \
--batch_size=32 \
--num_workers=4 \
--log_freq=20 \
--eval_freq=-1 \
--save_checkpoint=true \
--save_freq=2000
```
### Common Practices
For fine-tuning on a comparatively small dataset, such as a single LIBERO suite
or a real-world dataset with fewer than 200 demonstrations, a global batch size of
16 to 32 is a good starting point. In these settings,
`policy.enable_lora_vlm=true` or `policy.train_action_expert_only=true` is also
a practical choice. In both cases, we intentionally keep the action expert fully
trainable, which we found to be crucial for model performance. For larger
fine-tuning datasets, larger global batch sizes and full fine-tuning are usually
preferred.
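To relate a target global batch size to launch flags, here is a small sketch. It assumes that `--batch_size` applies per process, so the global batch size is the per-process value times `--num_processes`; this matches the "per-GPU batch size 32 on 8 H100 GPUs" wording in the results section, but verify it for your setup. Both helper names are illustrative and not part of LeRobot or Accelerate.

```python
def global_batch_size(per_process_batch_size: int, num_processes: int) -> int:
    """Global batch size under the assumption that batch_size is per process."""
    return per_process_batch_size * num_processes


def per_process_for_target(target_global: int, num_processes: int) -> int:
    """Per-process batch size needed to reach a target global batch size."""
    if target_global % num_processes != 0:
        raise ValueError("target global batch size must be divisible by num_processes")
    return target_global // num_processes


# The example commands in this document use 8 processes with batch_size=32.
print(global_batch_size(32, 8))       # 256
# A global batch size of 32 on 8 GPUs would need batch_size=4 per process.
print(per_process_for_target(32, 8))  # 4
```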
### Common Policy Options
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint to initialize from.
Use this for released MolmoAct2 weights.
- `policy.path`: LeRobot checkpoint to initialize from. Use this for checkpoints
created by LeRobot training.
- `policy.action_mode`: training target, one of `continuous`, `discrete`, or
`both`. `both` trains the flow-matching action expert and the discrete
action-token loss.
- `policy.train_action_expert_only`: trains only parameters whose names contain
`action_expert`. It requires `policy.action_mode=continuous`.
- `policy.enable_lora_vlm`: enables LoRA on VLM linear layers. Use
`policy.enable_lora_action_expert=true` only if LoRA should also cover action
expert linear layers. When `policy.enable_lora_action_expert=false`, the
action expert base weights remain fully trainable while the VLM is trained
through LoRA adapters. When `policy.enable_lora_action_expert=true`, the
action expert is also adapter-tuned instead of fully fine-tuned.
- `policy.enable_knowledge_insulation`: when `true`, detaches action-expert
context K/V states before the action loss. The default is `false`.
- `policy.chunk_size`: action horizon used by the policy. For LIBERO we use
`10`. This LeRobot port overrides the loaded checkpoint's
`max_action_horizon` with this value.
- `policy.n_action_steps`: number of actions consumed from each predicted
chunk before querying the policy again. For LIBERO, set it to `chunk_size`.
- `policy.setup_type`: text inserted into the prompt to describe the robot and
scene, e.g. `single franka robotic arm in libero`. More examples are listed
in the `metadata_by_tag` entries of
[`norm_stats.json`](https://huggingface.co/allenai/MolmoAct2/blob/main/norm_stats.json).
- `policy.control_mode`: text inserted into the prompt to describe the action
space, e.g. `delta end-effector pose` or `absolute joint pose`.
- `policy.image_keys`: ordered LeRobot image observation keys passed to the
processor.
- `policy.model_dtype`: checkpoint/forward dtype, one of `float32`,
`bfloat16`, or `float16`. Use `bfloat16` for normal training.
- `policy.num_flow_timesteps`: number of flow-matching timesteps sampled per
example during training. We use `8` for fine-tuning.
- `policy.num_inference_steps`: optional override for continuous action
generation steps at inference time.
- `policy.gradient_checkpointing`: enables checkpointing in the VLM/action path
to reduce activation memory.
- `policy.freeze_embedding`: freezes input embeddings. The default is `true`.
- `policy.normalize_gripper`: controls whether gripper dimensions are included
in state/action quantile normalization. The default is `false`.
- `policy.normalize_language`: normalizes task strings before prompt
construction. The default is `true`.
- `policy.mask_action_dim_padding`: masks padded dimensions in the flow loss.
Released checkpoints use `policy.expected_max_action_dim=32`.
- `policy.max_sequence_length`: optional manual sequence cap. Leave unset to
infer it from images, state dimension, action dimension, action horizon, and
discrete-action mode.
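The relationship between `policy.chunk_size` and `policy.n_action_steps` described in the list above can be sketched as a simple action queue. This is an illustration of the general chunked-execution idea, not the MolmoAct2 implementation; `ChunkedActionQueue` and `fake_policy` are hypothetical names.

```python
from collections import deque
from typing import Callable, Deque, List, Sequence

Action = List[float]


class ChunkedActionQueue:
    """Serve actions one at a time from predicted chunks (illustrative only)."""

    def __init__(
        self,
        predict_chunk: Callable[[object], Sequence[Action]],
        chunk_size: int,
        n_action_steps: int,
    ) -> None:
        if not 1 <= n_action_steps <= chunk_size:
            raise ValueError("n_action_steps must be between 1 and chunk_size")
        self._predict_chunk = predict_chunk
        self._chunk_size = chunk_size
        self._n_action_steps = n_action_steps
        self._queue: Deque[Action] = deque()
        self.num_queries = 0  # how many times the policy was asked for a chunk

    def select_action(self, observation: object) -> Action:
        if not self._queue:
            chunk = list(self._predict_chunk(observation))
            if len(chunk) != self._chunk_size:
                raise ValueError("policy returned a chunk of unexpected length")
            # Keep only the first n_action_steps actions; the rest are discarded.
            self._queue.extend(chunk[: self._n_action_steps])
            self.num_queries += 1
        return self._queue.popleft()


def fake_policy(observation: object) -> List[Action]:
    # Stand-in for a real policy: returns 10 dummy one-dimensional actions.
    return [[float(i)] for i in range(10)]


full = ChunkedActionQueue(fake_policy, chunk_size=10, n_action_steps=10)
half = ChunkedActionQueue(fake_policy, chunk_size=10, n_action_steps=5)
for _ in range(20):
    full.select_action(observation=None)
    half.select_action(observation=None)
print(full.num_queries, half.num_queries)  # 2 4
```

Setting `n_action_steps` equal to `chunk_size` queries the policy least often; a smaller value re-plans more frequently at the cost of more forward passes.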
### Learning Rates
MolmoAct2 uses parameter-group learning rates to match the original MolmoAct2
fine-tuning experiments.
- Full fine-tuning uses `policy.optimizer_lr=1e-5` for the VLM,
`policy.optimizer_vit_lr=5e-6` for the vision tower,
`policy.optimizer_connector_lr=5e-6` for image connector layers, and
`policy.optimizer_action_expert_lr=5e-5` for the action expert.
- LoRA VLM fine-tuning sets the VLM, vision, and connector LoRA parameter
groups to `5e-5` when `policy.enable_lora_vlm=true`. By default,
`policy.enable_lora_action_expert=false`, so the action expert is still fully
fine-tuned with `policy.optimizer_action_expert_lr`. If
`policy.enable_lora_action_expert=true`, the action expert is trained through
LoRA adapters instead.
- Action-expert-only fine-tuning trains only the action expert and uses
`policy.optimizer_action_expert_lr=5e-5`.
You can override the full fine-tuning and action-expert learning rates with
`policy.optimizer_lr`, `policy.optimizer_vit_lr`,
`policy.optimizer_connector_lr`, and `policy.optimizer_action_expert_lr`.
Scheduler settings can be changed with `policy.scheduler_warmup_steps`,
`policy.scheduler_decay_steps`, and `policy.scheduler_decay_lr`.
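The full fine-tuning learning rates listed above can be pictured as a mapping from parameter names to groups. The sketch below uses plain Python and made-up parameter names. Only the `action_expert` marker is taken from this document; the `vision_tower` and `connector` markers are assumptions for illustration, and the real policy may group parameters differently.

```python
from typing import Dict, List

# Full fine-tuning learning rates quoted in the text above.
FULL_FINETUNE_LRS: Dict[str, float] = {
    "vlm": 1e-5,
    "vit": 5e-6,
    "connector": 5e-6,
    "action_expert": 5e-5,
}

# Name markers used to classify parameters. Only "action_expert" is mentioned
# in this document; the other two markers are assumptions.
GROUP_MARKERS: Dict[str, str] = {
    "action_expert": "action_expert",
    "vit": "vision_tower",
    "connector": "connector",
}


def group_for(name: str) -> str:
    """Classify a parameter name; anything unmatched falls into the VLM group."""
    for group, marker in GROUP_MARKERS.items():
        if marker in name:
            return group
    return "vlm"


def build_param_groups(names: List[str]) -> List[dict]:
    """Group parameter names and attach the learning rate of each group."""
    grouped: Dict[str, List[str]] = {}
    for name in names:
        grouped.setdefault(group_for(name), []).append(name)
    return [
        {"group": group, "lr": FULL_FINETUNE_LRS[group], "params": members}
        for group, members in grouped.items()
    ]


# Hypothetical parameter names, for demonstration only.
example_names = [
    "model.layers.0.attn.weight",
    "model.vision_tower.blocks.0.weight",
    "model.connector.proj.weight",
    "model.action_expert.layers.0.weight",
]
for entry in build_param_groups(example_names):
    print(entry["group"], entry["lr"], len(entry["params"]))
```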
### Dataset Quantile Statistics
MolmoAct2 defaults to quantile normalization for state and action features. If
your dataset has not been converted with quantile statistics, you can add them
with:
```bash
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
--repo-id=your_dataset
```
Alternatively, train MolmoAct2 with mean/std normalization:
```bash
--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'
```
## Evaluation
Evaluation also supports both LeRobot-saved checkpoints and original MolmoAct2
HF checkpoints. For LIBERO replication, keep the EGL rendering environment
fixed and use `policy.per_episode_seed=true`.
**Important:** We found that `num_steps_wait=10` does not reliably let the
LIBERO scene stabilize and can degrade measured success. All LIBERO evaluation
results reported here use `num_steps_wait=50`.
### Evaluation With LeRobot MolmoAct2 Weight
Use `policy.path` for a checkpoint saved by LeRobot. The saved processor and
normalization statistics are restored together with the model.
```bash
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
lerobot-eval \
--policy.path=allenai/MolmoAct2-LIBERO-LeRobot \
--policy.inference_action_mode=continuous \
--policy.model_dtype=bfloat16 \
--policy.use_amp=true \
--policy.enable_inference_cuda_graph=true \
--policy.device=cuda \
--policy.per_episode_seed=true \
--policy.eval_seed=1000 \
--env.type=libero \
--env.task=libero_10,libero_goal,libero_object,libero_spatial \
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
--eval.batch_size=1 \
--eval.n_episodes=50 \
--seed=1000
```
### Evaluation With Original MolmoAct2 Weight
You can evaluate a released Hugging Face checkpoint directly without first
converting it to a LeRobot checkpoint. In this case, set
`policy.checkpoint_path` to the HF model repo and provide `policy.norm_tag`.
For LIBERO, `policy.norm_tag=libero` loads the LIBERO action/state
normalization statistics, action horizon, prompt metadata, and image-key order
from the checkpoint's `norm_stats.json`.
To fully replicate the MolmoAct2 paper results with released Hugging Face
checkpoints, we recommend using the v0.5.1-pinned
[`allenai/lerobot` `molmoact2-hf-inference`](https://github.com/allenai/lerobot/tree/molmoact2-hf-inference)
branch. That branch matches the original evaluation settings used for the
reported numbers.
```bash
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
lerobot-eval \
--policy.type=molmoact2 \
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
--policy.norm_tag=libero \
--policy.inference_action_mode=continuous \
--policy.model_dtype=float32 \
--policy.use_amp=false \
--policy.enable_inference_cuda_graph=true \
--policy.device=cuda \
--policy.per_episode_seed=true \
--policy.eval_seed=1000 \
--env.type=libero \
--env.task=libero_goal \
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
--eval.batch_size=1 \
--eval.n_episodes=50 \
--seed=1000
```
Use `--env.task=libero_10,libero_goal,libero_object,libero_spatial` to run the
full LIBERO suite. The same command works for other released MolmoAct2
checkpoints as long as the requested `policy.norm_tag` exists in that
checkpoint's `norm_stats.json`.
### Common Evaluation Options
- `policy.inference_action_mode`: required for rollout. Use `continuous` for
flow-matching inference or `discrete` for action-token inference. It must be
compatible with the training-time `policy.action_mode` saved in the
checkpoint.
- `policy.path`: LeRobot checkpoint path or Hub repo. Use this for checkpoints
saved by LeRobot.
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint path or Hub repo.
Use this with `policy.type=molmoact2` and `policy.norm_tag`.
- `policy.norm_tag`: selects normalization statistics, prompt metadata,
image-key order, and action horizon from the original checkpoint's
`norm_stats.json`. It is required for direct original-HF checkpoint
evaluation.
- `policy.model_dtype`: model load/forward dtype. Use `bfloat16` for normal
GPU evaluation. Use `float32` only when you explicitly want fp32 inference.
- `policy.use_amp`: runs the policy forward under autocast during eval. For
`model_dtype=bfloat16`, keep this enabled.
- `policy.enable_inference_cuda_graph`: enables the MolmoAct2 inference CUDA
graph path for faster repeated continuous-action rollout.
- `policy.per_episode_seed` and `policy.eval_seed`: make stochastic continuous
action generation deterministic per episode for replication.
- `env.task`: comma-separated LIBERO suites or a single suite. Use
`libero_10,libero_goal,libero_object,libero_spatial` for the full benchmark.
- `env.camera_name_mapping`: maps LIBERO camera names to the image keys expected
by the policy processor.
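
As a rough illustration of what `env.camera_name_mapping` expresses, the sketch below renames the camera keys of a toy observation dictionary. The `remap_camera_keys` helper and the observation layout are hypothetical and are not LeRobot's actual implementation; only the mapping values are taken from the evaluation command.

```python
# Hypothetical helper: illustrates key renaming only, not LeRobot's internal code.
def remap_camera_keys(observation: dict, mapping: dict[str, str]) -> dict:
    """Return a copy of `observation` with camera keys renamed via `mapping`."""
    return {mapping.get(key, key): value for key, value in observation.items()}


# Mapping values copied from the evaluation command on this page.
camera_name_mapping = {
    "agentview_image": "image",
    "robot0_eye_in_hand_image": "wrist_image",
}

# Toy observation; real observations hold image arrays rather than strings.
raw_observation = {
    "agentview_image": "<agentview frame>",
    "robot0_eye_in_hand_image": "<wrist frame>",
    "robot_state": [0.0, 0.1, 0.2],
}

remapped = remap_camera_keys(raw_observation, camera_name_mapping)
print(sorted(remapped))
```

Keys that are absent from the mapping, such as `robot_state` here, pass through unchanged in this sketch.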
## Performance Results
### LIBERO Benchmark Results
MolmoAct2 has demonstrated strong performance on the LIBERO benchmark suite. To
compare and test its LeRobot implementation, we fine-tuned
[`allenai/MolmoAct2-LIBERO`](https://huggingface.co/allenai/MolmoAct2-LIBERO)
for an additional 10k steps on the LIBERO dataset with per-GPU batch size 32 on
8 H100 GPUs, then compared the results to the original MolmoAct2 reference
results.
The LeRobot fine-tuned checkpoint reported here is available at
[`allenai/MolmoAct2-LIBERO-LeRobot`](https://huggingface.co/allenai/MolmoAct2-LIBERO-LeRobot)
and was trained on
[`allenai/MolmoAct2-LIBERO-Dataset`](https://huggingface.co/datasets/allenai/MolmoAct2-LIBERO-Dataset).
| Benchmark | LeRobot Implementation | MolmoAct2 Original |
| -------------- | ---------------------: | -----------------: |
| LIBERO Spatial | 98.4% | 97.8% |
| LIBERO Object | 100.0% | 100.0% |
| LIBERO Goal | 98.0% | 97.8% |
| LIBERO 10 | 96.6% | 93.2% |
| Average | 98.25% | 97.20% |
On all four suites, the LeRobot implementation matches or exceeds the original
MolmoAct2 reference results. To reproduce them, follow the instructions in the
LIBERO evaluation section.
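
The "Average" row can be checked with a few lines of Python. This sketch assumes the average is the unweighted mean over the four suites; the per-suite numbers are copied from the table, and the dictionary names are only illustrative.

```python
# Per-suite success rates (percent) copied from the results table.
lerobot_rates = {"spatial": 98.4, "object": 100.0, "goal": 98.0, "libero_10": 96.6}
original_rates = {"spatial": 97.8, "object": 100.0, "goal": 97.8, "libero_10": 93.2}


def mean_rate(rates: dict[str, float]) -> float:
    """Unweighted mean across suites, rounded to two decimals."""
    return round(sum(rates.values()) / len(rates), 2)


lerobot_average = mean_rate(lerobot_rates)
original_average = mean_rate(original_rates)
print(lerobot_average, original_average)
```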
## Differences From the Original Implementation
This LeRobot port is intended to match MolmoAct2 behavior while using LeRobot's
dataset, training, evaluation, checkpoint, and logging infrastructure. The main
differences from the original training repository are:
- The original paper training stack loads the model in fp32 and trains under
mixed precision. This LeRobot port usually loads the checkpoint directly in
`policy.model_dtype=bfloat16` for lower memory use.
- The original repository uses its own FSDP/model-parallel training path. The
LeRobot port uses the standard LeRobot/Accelerate training path and has not
been tested for multi-node training.
- The original repository supports sequence packing. The LeRobot port trains on
one LeRobot sample per item and pads to an inferred fixed sequence budget.
- The LeRobot port follows LeRobot's optimizer, scheduler, checkpoint saving,
dataset transforms, image augmentation, and Weights & Biases logging
conventions.
- The original training path supports mixed action horizons by padding to
`max_action_horizon` and masking padded horizon slots in the action expert
self-attention. This is useful when training across datasets with different
control frequencies. The LeRobot port currently targets single-dataset
fine-tuning, so `policy.chunk_size` overrides the checkpoint
`max_action_horizon` and horizon masking is not implemented yet. Support for
this mixed-horizon path is planned.
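
To make the mixed-horizon idea concrete, here is a minimal sketch of padding an action chunk to `max_action_horizon` and deriving a mask that hides the padded slots from attention. It is an illustration under assumptions, not the original repository's code: the function names, the zero padding value, and the list-based layout are all hypothetical.

```python
# Hypothetical sketch: pad a variable-length action chunk and mask the padding.
def pad_action_chunk(actions: list[list[float]], max_action_horizon: int):
    """Pad `actions` (horizon x action_dim) with zero rows up to `max_action_horizon`.

    Returns the padded chunk and a boolean list that is True for real steps.
    """
    if not actions:
        raise ValueError("expected at least one action step")
    horizon = len(actions)
    if horizon > max_action_horizon:
        raise ValueError("chunk is longer than max_action_horizon")
    action_dim = len(actions[0])
    num_padding = max_action_horizon - horizon
    padded = actions + [[0.0] * action_dim for _ in range(num_padding)]
    valid = [True] * horizon + [False] * num_padding
    return padded, valid


def key_padding_attention_mask(valid: list[bool]) -> list[list[bool]]:
    """mask[q][k] is True when query slot q may attend to key slot k (real slots only)."""
    return [list(valid) for _ in valid]


# A 3-step chunk of 2-dimensional actions padded to a horizon of 5.
chunk = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
padded_chunk, valid_slots = pad_action_chunk(chunk, max_action_horizon=5)
attention_mask = key_padding_attention_mask(valid_slots)
print(valid_slots)
```

With a mask like this, datasets recorded at different control frequencies could share one batch shape while the padded slots stay invisible to the action expert.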
## Citation
```bibtex
@misc{fang2026molmoact2actionreasoningmodels,
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
year={2026},
eprint={2605.02881},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2605.02881},
}
```
## License
This model is licensed under Apache 2.0. It is intended for research and
educational use in accordance with
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).

View File

@@ -1,39 +0,0 @@
# MolmoAct2
This repository contains the LeRobot policy implementation of
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into LeRobot for
training, evaluation, checkpointing, and dataset compatibility.
This implementation currently supports training and evaluation for the regular
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
not included in this LeRobot policy yet and is coming soon.
For the original MolmoAct2 training code used for the experiments reported in
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
## LIBERO Evaluation
Important: we found that `num_steps_wait=10` does not reliably let the LIBERO
scene stabilize and can degrade measured success. All LIBERO evaluation results
reported for this LeRobot implementation use `num_steps_wait=50`.
## Citation
```bibtex
@misc{fang2026molmoact2actionreasoningmodels,
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
year={2026},
eprint={2605.02881},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2605.02881},
}
```
## License
This model is licensed under Apache 2.0. It is intended for research and
educational use in accordance with
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).

View File

@@ -1,232 +0,0 @@
# VLA-JEPA
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
---
## Architecture Overview
VLA-JEPA has three main components:
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
### Data flow
**Training:**
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
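
The way the two training losses combine in step 4 can be sketched as follows. The weighted-sum form is an assumption based on the description of `world_model_loss_weight` as the weight of the JEPA prediction loss relative to the action loss; `total_training_loss` is a hypothetical helper, not a function from the policy code.

```python
# Hypothetical helper: assumes the JEPA loss is added as a weighted term.
def total_training_loss(
    action_loss: float,
    world_model_loss: float,
    world_model_loss_weight: float = 0.1,
    enable_world_model: bool = True,
) -> float:
    """Combine the flow-matching action loss with the world-model regression loss."""
    if not enable_world_model:
        # With the world model disabled, only the action loss is optimised.
        return action_loss
    return action_loss + world_model_loss_weight * world_model_loss


combined = total_training_loss(action_loss=0.5, world_model_loss=2.0)
action_only = total_training_loss(0.5, 2.0, enable_world_model=False)
print(combined, action_only)
```
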
**Inference:**
Only Qwen + the action head are used. The world model is not needed at inference time.
### Action head details
Available presets via `action_model_type`:
| Preset | Hidden dim | Heads | Head dim |
| ------- | ---------- | ----- | -------- |
| `DiT-B` | 768 | 12 | 64 |
| `DiT-L` | 1536 | 32 | 48 |
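
In both presets the hidden dimension equals the number of heads times the per-head dimension. The short check below uses the values from the table; the dictionary layout and key names are illustrative rather than the policy's actual configuration schema.

```python
# Values copied from the preset table; the dictionary layout is illustrative.
ACTION_HEAD_PRESETS = {
    "DiT-B": {"hidden_dim": 768, "num_heads": 12, "head_dim": 64},
    "DiT-L": {"hidden_dim": 1536, "num_heads": 32, "head_dim": 48},
}

for name, preset in ACTION_HEAD_PRESETS.items():
    # hidden_dim must split evenly across attention heads.
    assert preset["num_heads"] * preset["head_dim"] == preset["hidden_dim"], name
    print(f"{name}: {preset['num_heads']} heads x {preset['head_dim']} = {preset['hidden_dim']}")
```
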
### World model details
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
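
A block-causal mask of the kind described above can be sketched in plain Python. This is an illustration, not the predictor's implementation: `block_causal_mask` is a hypothetical helper, and it assumes that tokens may also attend to other tokens within their own temporal step.

```python
# Hypothetical sketch of a block-causal attention mask over temporal steps.
def block_causal_mask(num_steps: int, tokens_per_step: int) -> list[list[bool]]:
    """mask[q][k] is True when query token q may attend to key token k.

    Tokens are grouped into temporal steps; a token sees every token in its
    own step (an assumption) and in all earlier steps, but none in later steps.
    """
    total = num_steps * tokens_per_step
    mask = []
    for q in range(total):
        q_step = q // tokens_per_step
        mask.append([(k // tokens_per_step) <= q_step for k in range(total)])
    return mask


mask = block_causal_mask(num_steps=3, tokens_per_step=2)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))

# Predictor input width for the pretrained checkpoints: 2 views x 1024.
num_views, video_encoder_hidden_size = 2, 1024
embed_dim = num_views * video_encoder_hidden_size
```
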
---
## Pretrained Checkpoints
Three checkpoints are available, converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
| Checkpoint | Dataset | Cameras | World model | Action dim |
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
---
## Configuration
Key parameters in `VLAJEPAConfig`:
| Parameter | Default | Description |
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `chunk_size` | 7 | Number of actions predicted per inference call |
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 8 | Video clip length fed to the world model |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
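
To show what `num_inference_timesteps` controls, the following sketch runs fixed-step Euler integration of a velocity field. It assumes time runs from 0 (noise) to 1 (action) and replaces the action head with a toy velocity function, so `euler_denoise` and `toy_velocity` are hypothetical stand-ins rather than the policy's real sampler.

```python
# Hypothetical sketch of Euler integration for flow-matching action denoising.
def euler_denoise(noise: list[float], velocity_fn, num_inference_timesteps: int = 4) -> list[float]:
    """Integrate dx/dt = velocity_fn(x, t) from t=0 to t=1 with equal Euler steps."""
    x = list(noise)
    dt = 1.0 / num_inference_timesteps
    for step in range(num_inference_timesteps):
        t = step * dt
        velocity = velocity_fn(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, velocity)]
    return x


# Toy stand-in for the action head: the ideal straight-line velocity to a target.
target_action = [0.5, -0.25, 1.0]


def toy_velocity(x: list[float], t: float) -> list[float]:
    return [(target - xi) / (1.0 - t) for target, xi in zip(target_action, x)]


denoised = euler_denoise([0.0, 0.0, 0.0], toy_velocity, num_inference_timesteps=4)
print(denoised)
```
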
---
## Training
The number of training steps depends on dataset size and compute budget. The original paper pretrained for 50k steps on SSv2 + DROID jointly, then trained for an additional 30k steps on LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
### Full training from scratch
```bash
lerobot-train \
--policy.type=vla_jepa \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/your_dataset
```
### Fine-tuning from a pretrained checkpoint
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/your_dataset
```
If you want to go further and freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--dataset.repo_id=your_org/your_dataset
```
### Fine-tuning on a different embodiment
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
The layers that depend on `action_dim` and `state_dim` are:
| Layer | Key prefix |
| ----------------------------------------- | ----------------------------------- |
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
--dataset.repo_id=your_org/your_dataset
```
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
### Reproducing the LIBERO results
**Training on LIBERO:**
This starts training from the Pretrain checkpoint and trains for 30k steps on the LIBERO dataset.
The original paper reports training across 8 GPUs with a per-GPU batch size of 32, i.e. a global batch size of 256.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=HuggingFaceVLA/libero \
--steps=30000
```
**Evaluating the pretrained LIBERO-10 checkpoint:**
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.n_episodes=10 \
--eval.batch_size=5
```
To evaluate a subset of tasks only:
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_10 \
--env.task_ids='[0,1,2]' \
--eval.n_episodes=10 \
--eval.batch_size=5
```
**Expected results:**
| Suite | Episodes | Successes | Success Rate |
| -------------- | -------- | --------- | ------------ |
| libero_spatial | 100 | 93 | **95.0%** |
| libero_object | 100 | 100 | **100.0%** |
| libero_goal | 100 | 98 | **98.0%** |
| libero_10 | 100 | 96 | **93.0%** |
| **Overall** | **400** | **387** | **96.5%** |
---
## Fine-tuning on datasets with a different number of cameras
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
**Default behaviour — view padding / trimming (no action required)**
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
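
The padding and trimming behaviour described above can be sketched as a small selection function. This is an illustration under assumptions: `select_world_model_views` is a hypothetical helper, and the cyclic repetition used when fewer views are available generalises the documented single-view duplication.

```python
# Hypothetical sketch of view padding / trimming for the world-model input.
def select_world_model_views(view_latents: list, num_world_model_views: int = 2) -> list:
    """Return exactly `num_world_model_views` latents, in the configured view order.

    Extra views are dropped; missing views are filled by repeating the
    available ones (a single view is simply duplicated).
    """
    if not view_latents:
        raise ValueError("expected at least one view")
    if len(view_latents) >= num_world_model_views:
        return view_latents[:num_world_model_views]
    return [view_latents[i % len(view_latents)] for i in range(num_world_model_views)]


single_view = select_world_model_views(["front"])
three_views = select_world_model_views(["wrist", "exterior_left", "exterior_right"])
print(single_view, three_views)
```

In this sketch the Qwen backbone would still receive every view; only the world-model input is padded or trimmed.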
**Option 1 — Disable the world model**
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.enable_world_model=false \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/single_camera_dataset
```
**Option 2 — Reinitialize the predictor input projection**
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.

View File

@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
Once you are logged in, you can run inference in your setup by doing:
```bash
lerobot-rollout \
--strategy.type=base \
lerobot-record \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \ # <- Use your port
--robot.id=my_blue_follower_arm \ # <- Use your robot id
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
# <- RTC optional, use when running on low power hardware \
# --inference.type=rtc \
# --inference.rtc.execution_horizon=10 \
# --inference.rtc.max_guidance_weight=10.0 \
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
--dataset.episode_time_s=50 \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
# --teleop.id=my_red_leader_arm \
# --display_data=true #optional use if you want to see the camera stream \
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
```

View File

@@ -15,12 +15,10 @@
# limitations under the License.
"""
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
Downloads datasets from HuggingFace, seeks directly into the episode segment
of the source video, draws a progress line on each frame, and writes the result.
The progress data is read from a parquet file that lives alongside the dataset
(configurable via ``--progress-file``).
Usage:
python examples/dataset/create_progress_videos.py \
@@ -58,26 +56,22 @@ SCORE_FONT_SCALE = 0.8
TASK_FONT_SCALE = 0.55
def download_episode_metadata(
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
) -> Path:
"""Download only the metadata and per-frame progress file for a dataset.
def download_episode_metadata(repo_id: str, episode: int) -> Path:
"""Download only the metadata and sarm_progress files for a dataset.
Args:
repo_id: HuggingFace dataset repository ID.
episode: Episode index (used for logging only; all meta is fetched).
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns:
Local cache path for the downloaded snapshot.
"""
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
local_path = Path(
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
allow_patterns=["meta/**", progress_file],
allow_patterns=["meta/**", "sarm_progress.parquet"],
ignore_patterns=["*.mp4"],
)
)
@@ -221,28 +215,25 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
return video_path
def load_progress_data(
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
) -> np.ndarray | None:
"""Load per-frame progress values for an episode.
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
"""Load sarm_progress values for an episode.
Args:
local_path: Dataset cache root.
episode: Episode index.
progress_file: Filename of the per-frame progress parquet.
Returns:
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
"""
parquet_path = local_path / progress_file
parquet_path = local_path / "sarm_progress.parquet"
if not parquet_path.exists():
logging.warning("%s not found", progress_file)
logging.warning("sarm_progress.parquet not found")
return None
df = pd.read_parquet(parquet_path)
logging.info(" %s columns: %s", progress_file, list(df.columns))
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
episode_df = df[df["episode_index"] == episode].copy()
if episode_df.empty:
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
logging.warning("No sarm_progress rows for episode %d", episode)
return None
episode_df = episode_df.sort_values("frame_index")
@@ -585,7 +576,6 @@ def process_dataset(
camera_key: str | None,
output_dir: Path,
create_gif: bool = False,
progress_file: str = "sarm_progress.parquet",
) -> Path | None:
"""Full pipeline: download, extract metadata, composite progress, write output.
@@ -595,8 +585,6 @@ def process_dataset(
camera_key: Camera key to use, or None for auto-selection.
output_dir: Directory to write output files.
create_gif: If True, also generate a GIF from the MP4.
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns:
Path to the final output file, or None on failure.
@@ -604,7 +592,7 @@ def process_dataset(
safe_name = repo_id.replace("/", "_")
logging.info("Processing: %s | episode %d", repo_id, episode)
local_path = download_episode_metadata(repo_id, episode, progress_file)
local_path = download_episode_metadata(repo_id, episode)
logging.info(" Local cache: %s", local_path)
episode_meta = load_episode_meta(local_path, episode, camera_key)
@@ -612,9 +600,9 @@ def process_dataset(
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
progress_data = load_progress_data(local_path, episode, progress_file)
progress_data = load_progress_data(local_path, episode)
if progress_data is None:
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
logging.error("Could not load sarm_progress data. Skipping overlay.")
return None
logging.info(" Progress frames: %d", len(progress_data))
@@ -639,7 +627,7 @@ def process_dataset(
def main() -> None:
parser = argparse.ArgumentParser(
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
)
parser.add_argument(
"--repo-id",
@@ -670,15 +658,6 @@ def main() -> None:
action="store_true",
help="Also generate a GIF from the MP4 output.",
)
parser.add_argument(
"--progress-file",
type=str,
default="sarm_progress.parquet",
help=(
"Filename of the per-frame progress parquet inside the dataset repo "
"(default: 'sarm_progress.parquet')."
),
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
@@ -691,7 +670,6 @@ def main() -> None:
camera_key=args.camera_key,
output_dir=args.output_dir,
create_gif=args.gif,
progress_file=args.progress_file,
)
if result:

View File

@@ -138,9 +138,7 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
# Common
av-dep = ["av>=15.0.0,<16.0.0"]
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
placo-dep = ["placo>=0.9.6,<0.9.16"]
placo-dep = ["placo>=0.9.6,<0.9.17"]
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
@@ -198,7 +196,6 @@ wallx = [
"lerobot[qwen-vl-utils-dep]",
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
groot = [
@@ -216,7 +213,6 @@ topreward = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
@@ -277,12 +273,10 @@ all = [
"lerobot[multi_task_dit]",
"lerobot[wallx]",
"lerobot[pi]",
"lerobot[molmoact2]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[vla_jepa]",
"lerobot[async]",
"lerobot[dev]",
"lerobot[test]",
@@ -409,11 +403,8 @@ default.extend-ignore-identifiers-re = [
"ein",
"thw",
"inpt",
"arange",
"is_compileable",
"ROBOTIS",
"OT_VALUE",
"VanderBilt"
"OT_VALUE"
]
# TODO: Uncomment when ready to use

View File

@@ -255,7 +255,8 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) ->
remaining = config_data[field]
if remaining:
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
del config_data[field]
else:
del config_data[field]
modified = True
if not modified:
@@ -310,13 +311,7 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
cli_args = filter_arg("config_path", cli_args)
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
else:
if config_path_cli:
cli_args = filter_arg("config_path", cli_args)
cfg = draccus.parse(
config_class=argtype,
config_path=config_path_cli or config_path,
args=cli_args,
)
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
response = fn(cfg, *args, **kwargs)
return response

View File

@@ -250,14 +250,7 @@ class DatasetWriter:
for key, ft in self._meta.features.items():
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
stacked_values = np.stack(episode_buffer[key])
# `shape=(1,)` numeric features are serialized as `datasets.Value`, which expects scalars.
# Normalizing to `(N,)` keeps save semantics stable across dependency versions.
if tuple(ft["shape"]) == (1,) and ft["dtype"] != "string":
stacked_values = stacked_values.reshape(episode_length)
episode_buffer[key] = stacked_values
episode_buffer[key] = np.stack(episode_buffer[key])
# Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer()

View File

@@ -17,13 +17,11 @@ import contextlib
import glob
import importlib
import logging
import os
import queue
import shutil
import tempfile
import threading
import warnings
from collections import OrderedDict
from dataclasses import asdict, dataclass, field
from fractions import Fraction
from pathlib import Path
@@ -193,70 +191,15 @@ def decode_video_frames_pyav(
return closest_frames
DEFAULT_DECODER_CACHE_SIZE = 100
"""Default LRU capacity for :class:`VideoDecoderCache`.
Sized to comfortably hold a small rolling window of episodes worth of decoders
(typical recipes: 2-4 cameras per episode × tens of episodes in flight) while
bounding host RAM. Each cached entry retains a torchcodec ``VideoDecoder`` plus
an open ``fsspec`` file handle — on the order of a few MB per entry. Override
via the ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var or by passing ``max_size``
to the constructor (``None`` restores the legacy unbounded behaviour).
"""
def _default_max_cache_size() -> int | None:
raw = os.environ.get("LEROBOT_VIDEO_DECODER_CACHE_SIZE")
if raw is None:
return DEFAULT_DECODER_CACHE_SIZE
raw = raw.strip().lower()
if raw in ("", "none", "unbounded", "-1"):
return None
try:
value = int(raw)
except ValueError as e:
raise ValueError(
f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be an integer, 'none', or '-1'; got {raw!r}"
) from e
if value <= 0:
raise ValueError(f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be positive; got {value}")
return value
class VideoDecoderCache:
"""Thread-safe LRU cache for torchcodec ``VideoDecoder`` instances.
"""Thread-safe cache for video decoders to avoid expensive re-initialization."""
Cached entries hold a ``VideoDecoder`` plus the open ``fsspec`` file handle
backing it. When the cache is full and a new path is requested, the
least-recently-used entry is evicted and its file handle is closed. This
bounds host-RAM growth when iterating over datasets with many distinct
video files (otherwise each ``DataLoader`` worker pins every decoder it has
ever opened until the process exits).
Args:
max_size: Maximum number of decoders to retain. ``None`` disables
eviction and restores legacy unbounded behaviour. Defaults to the
value of ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` if set, otherwise
:data:`DEFAULT_DECODER_CACHE_SIZE`.
"""
_SENTINEL: ClassVar[object] = object()
def __init__(self, max_size: int | None | object = _SENTINEL):
if max_size is VideoDecoderCache._SENTINEL:
max_size = _default_max_cache_size()
if max_size is not None and max_size <= 0:
raise ValueError(f"max_size must be positive or None; got {max_size}")
self.max_size: int | None = max_size # type: ignore[assignment]
self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
def __init__(self):
self._cache: dict[str, tuple[Any, Any]] = {}
self._lock = Lock()
def __contains__(self, video_path: object) -> bool:
with self._lock:
return str(video_path) in self._cache
def get_decoder(self, video_path: str):
"""Get a cached decoder or create a new one, evicting LRU if at capacity."""
"""Get a cached decoder or create a new one."""
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
@@ -268,36 +211,22 @@ class VideoDecoderCache:
video_path = str(video_path)
with self._lock:
entry = self._cache.get(video_path)
if entry is not None:
self._cache.move_to_end(video_path)
return entry[0]
if video_path not in self._cache:
file_handle = fsspec.open(video_path).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
file_handle.close()
raise
self._cache[video_path] = (decoder, file_handle)
file_handle = fsspec.open(video_path).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
file_handle.close()
raise
self._cache[video_path] = (decoder, file_handle)
# Evict LRU entries until we are back under the cap. We close
# evicted file handles immediately; the associated ``VideoDecoder``
# is released to the GC when its last reference goes away.
if self.max_size is not None:
while len(self._cache) > self.max_size:
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
with contextlib.suppress(Exception):
evicted_handle.close()
return decoder
return self._cache[video_path][0]
def clear(self):
"""Clear the cache and close all file handles."""
"""Clear the cache and close file handles."""
with self._lock:
for _, file_handle in self._cache.values():
with contextlib.suppress(Exception):
file_handle.close()
file_handle.close()
self._cache.clear()
def size(self) -> int:

View File

@@ -18,25 +18,12 @@ from typing import TYPE_CHECKING
import numpy as np
from lerobot.utils.import_utils import require_package
from lerobot.utils.import_utils import _placo_available, require_package
_placo_runtime_error: ImportError | None = None
if TYPE_CHECKING:
if TYPE_CHECKING or _placo_available:
import placo # type: ignore[import-not-found]
else:
try:
import placo # type: ignore[import-not-found]
except ImportError as _placo_import_err:
placo = None
_placo_runtime_error = _placo_import_err
def _raise_if_placo_unusable() -> None:
if placo is None and _placo_runtime_error is not None:
raise ImportError(
f"placo is installed but failed to import: {_placo_runtime_error!s}"
) from _placo_runtime_error
placo = None
class RobotKinematics:
@@ -57,7 +44,6 @@ class RobotKinematics:
joint_names (list[str] | None): List of joint names to use for the kinematics solver
"""
require_package("placo", extra="placo-dep")
_raise_if_placo_unusable()
self.robot = placo.RobotWrapper(urdf_path)
self.solver = placo.KinematicsSolver(self.robot)

View File

@@ -43,7 +43,6 @@ from .tables import (
CAN_CMD_SET_ZERO,
DEFAULT_BAUDRATE,
DEFAULT_TIMEOUT_MS,
HANDSHAKE_TIMEOUT_S,
MODEL_RESOLUTION,
MOTOR_LIMIT_PARAMS,
NORMALIZED_DATA,
@@ -216,16 +215,14 @@ class RobstrideMotorsBus(MotorsBusBase):
self._is_connected = False
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
def _query_status_via_clear_fault(
self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
) -> tuple[bool, can.Message | None]:
def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
motor_name = self._get_motor_name(motor)
motor_id = self._get_motor_id(motor_name)
recv_id = self._get_motor_recv_id(motor_name)
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self._bus().send(msg)
return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)
return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
def _recv_status_via_clear_fault(
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
@@ -283,7 +280,7 @@ class RobstrideMotorsBus(MotorsBusBase):
faulted_motors = []
for motor_name in self.motors:
has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
has_fault, msg = self._query_status_via_clear_fault(motor_name)
if msg is None:
missing_motors.append(motor_name)
elif has_fault:
@@ -508,87 +505,6 @@ class RobstrideMotorsBus(MotorsBusBase):
return responses
def _recv_all_messages_until_quiet(
self,
*,
timeout: float = RUNNING_TIMEOUT,
max_messages: int = 4096,
) -> list[can.Message]:
"""
Receive frames until the bus goes quiet.
Args:
timeout: Poll timeout used for each recv() call. Collection stops
when one recv() times out (quiet gap).
max_messages: Safety cap to prevent unbounded loops.
"""
out: list[can.Message] = []
max_messages = max(1, max_messages)
timeout = max(0.0, timeout)
try:
while len(out) < max_messages:
msg = self._bus().recv(timeout=timeout)
if msg is None:
break
out.append(msg)
except (can.CanError, OSError) as e:
logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
return out
def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
"""
Decode all received feedback frames and update cached motor states.
Returns:
Set of payload recv_ids that were successfully mapped to motors.
"""
processed_recv_ids: set[int] = set()
for msg in messages:
if len(msg.data) < 1:
logger.debug(
f"Dropping short CAN frame on {self.port} "
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
)
continue
recv_id = int(msg.data[0])
motor_name = self._recv_id_to_motor.get(recv_id)
if motor_name is None:
logger.debug(
f"Unmapped CAN frame on {self.port} "
f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
)
continue
self._process_response(motor_name, msg)
processed_recv_ids.add(recv_id)
return processed_recv_ids
def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
"""
Drain pending RX frames from the CAN interface.
This is used by higher-level controllers to drop stale feedback before issuing
a fresh read cycle, so subsequent state reads are based on most recent replies.
It should also be called once when a controller instance is created/connected,
to clear residual frames left on the interface from previous sessions.
"""
drained = 0
poll_timeout_s = max(0.0, poll_timeout_s)
max_messages = max(1, max_messages)
try:
while drained < max_messages:
msg = self._bus().recv(timeout=poll_timeout_s)
if msg is None:
break
drained += 1
except (can.CanError, OSError) as e:
logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
return drained
def _speed_control(
self,
motor: NameOrID,
@@ -728,14 +644,11 @@ class RobstrideMotorsBus(MotorsBusBase):
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self._bus().send(msg)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
# Read every feedback frame until RX goes quiet, then decode all of them.
# This avoids dropping useful frames when responses from different motors interleave.
messages = self._recv_all_messages_until_quiet()
processed_recv_ids = self._process_feedback_messages(messages)
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
for recv_id, motor_name in recv_id_to_motor.items():
if recv_id not in processed_recv_ids:
logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")
if msg := responses.get(recv_id):
self._process_response(motor_name, msg)
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
"""Convert float to unsigned integer for CAN transmission."""
@@ -798,10 +711,7 @@ class RobstrideMotorsBus(MotorsBusBase):
try:
self._decode_motor_state(msg.data)
except Exception as e:
logger.warning(
f"Failed to decode response from {motor} "
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
)
logger.warning(f"Failed to decode response from {motor}: {e}")
def _get_cached_value(self, motor: str, data_name: str) -> Value:
"""Retrieve a specific value from the state cache."""
@@ -938,12 +848,20 @@ class RobstrideMotorsBus(MotorsBusBase):
self._bus().send(msg)
updated_motors.append(motor)
messages = self._recv_all_messages_until_quiet()
processed_recv_ids = self._process_feedback_messages(messages)
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
for response in responses.values():
payload_motor_name = self._recv_id_to_motor.get(response.data[0])
if payload_motor_name is not None:
self._process_response(payload_motor_name, response)
else:
# Fallback: still attempt to decode based on payload byte0 mapping.
self._decode_motor_state(response.data)
for motor in updated_motors:
recv_id = self._get_motor_recv_id(motor)
if recv_id not in processed_recv_ids:
if recv_id not in responses:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
def read_calibration(self) -> dict[str, MotorCalibration]:
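The two receive strategies in this file differ in when they stop reading: one waits for a fixed set of expected IDs, the other drains frames until a `recv()` call times out and then maps each frame to a motor by payload byte 0. The sketch below illustrates the second strategy. `FakeBus` and the plain `bytes` frames are stand-ins invented for the example (python-can messages expose the payload as `.data` instead), and the motor names and IDs are made up.

```python
from collections import deque


class FakeBus:
    """Stand-in for a CAN bus: recv() pops a queued frame or returns None on timeout."""

    def __init__(self, frames):
        self._frames = deque(frames)

    def recv(self, timeout: float = 0.0):
        return self._frames.popleft() if self._frames else None


def recv_until_quiet(bus, timeout: float = 0.001, max_messages: int = 4096) -> list:
    """Collect frames until one recv() times out or the safety cap is reached."""
    out = []
    cap = max(1, max_messages)
    while len(out) < cap:
        frame = bus.recv(timeout=max(0.0, timeout))
        if frame is None:  # quiet gap: nothing arrived within the poll timeout
            break
        out.append(frame)
    return out


def latest_frame_per_motor(frames, recv_id_to_motor: dict) -> dict:
    """Map frames to motors via payload byte 0; short or unmapped frames are skipped."""
    latest = {}
    for frame in frames:
        if len(frame) < 1:
            continue
        motor = recv_id_to_motor.get(frame[0])
        if motor is not None:
            latest[motor] = frame  # a later frame overwrites an earlier one
    return latest


mapping = {0x11: "shoulder", 0x12: "elbow", 0x13: "wrist"}
frames = [bytes([0x11, 1]), bytes([0x12, 2]), bytes([0x11, 3]), bytes([0x7F, 9])]
received = recv_until_quiet(FakeBus(frames))
latest = latest_frame_per_motor(received, mapping)
missing = set(mapping.values()) - set(latest)  # motors that would log a packet drop
```

Draining tolerates interleaved or duplicate replies, at the cost of always paying one extra poll timeout per read cycle.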

@@ -114,8 +114,7 @@ CAN_CMD_SAVE_PARAM = 0xAA
CAN_PARAM_ID = 0x7FF
RUNNING_TIMEOUT = 0.003
HANDSHAKE_TIMEOUT_S = 0.05
RUNNING_TIMEOUT = 0.001
PARAM_TIMEOUT = 0.01
STATE_CACHE_TTL_S = 0.02

@@ -20,7 +20,6 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
@@ -44,7 +43,6 @@ __all__ = [
"EO1Config",
"GaussianActorConfig",
"GrootConfig",
"MolmoAct2Config",
"MultiTaskDiTConfig",
"PI0Config",
"PI0FastConfig",

@@ -49,7 +49,6 @@ from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
from .pi05.configuration_pi05 import PI05Config
@@ -57,7 +56,6 @@ from .pretrained import PreTrainedPolicy
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from .vqbet.configuration_vqbet import VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig
from .xvla.configuration_xvla import XVLAConfig
@@ -90,8 +88,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x",
"molmoact2".
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -154,14 +151,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .eo1.modeling_eo1 import EO1Policy
return EO1Policy
elif name == "molmoact2":
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
return MolmoAct2Policy
elif name == "vla_jepa":
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
return VLAJEPAPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -179,7 +168,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
"smolvla", "wall_x", "molmoact2".
"smolvla", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -214,10 +203,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return WallXConfig(**kwargs)
elif policy_type == "eo1":
return EO1Config(**kwargs)
elif policy_type == "molmoact2":
return MolmoAct2Config(**kwargs)
elif policy_type == "vla_jepa":
return VLAJEPAConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -246,7 +231,6 @@ class ProcessorConfigKwargs(TypedDict, total=False):
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
dataset_meta: Any | None
def make_pre_post_processors(
@@ -422,7 +406,6 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, EO1Config):
from .eo1.processor_eo1 import make_eo1_pre_post_processors
@@ -431,23 +414,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, MolmoAct2Config):
from .molmoact2.processor_molmoact2 import make_molmoact2_pre_post_processors
processors = make_molmoact2_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, VLAJEPAConfig):
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
processors = make_vla_jepa_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
@@ -533,10 +499,6 @@ def make_policy(
action_names = ds_meta.features.get(ACTION, {}).get("names")
if action_names is not None:
cfg.action_feature_names = list(action_names)
if ds_meta is not None:
set_dataset_feature_metadata = getattr(cfg, "set_dataset_feature_metadata", None)
if callable(set_dataset_feature_metadata):
set_dataset_feature_metadata(ds_meta.features)
kwargs["config"] = cfg

@@ -60,7 +60,6 @@ class Eagle25VLPreTrainedModel(PreTrainedModel):
"SiglipEncoderLayer",
]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = True
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_static_cache = True

@@ -124,6 +124,7 @@ class Eagle25VLProcessor(ProcessorMixin):
"videos_kwargs",
"text_kwargs",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(

@@ -14,7 +14,7 @@
# limitations under the License.
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
import numpy as np
import torch
@@ -26,14 +26,9 @@ from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from huggingface_hub.dataclasses import strict
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
def strict(cls):
return cls
AutoConfig = None
AutoModel = None
PretrainedConfig = object
@@ -178,20 +173,19 @@ N_COLOR_CHANNELS = 3
# config
@strict
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
backbone_cfg: dict[str, Any] | None = None
action_head_cfg: dict[str, Any] | None = None
action_horizon: int = 0
action_dim: int = 0
backbone_cfg: dict
action_head_cfg: dict
action_horizon: int
action_dim: int
compute_dtype: str = "float32"
def __post_init__(self, **kwargs):
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
super().__post_init__(**kwargs)
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
# real model

@@ -206,11 +206,7 @@ def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS
"Vendor files are copied during model creation. Create the policy/model first, "
"or call ensure_eagle_cache_ready() before building processors."
)
proc = AutoProcessor.from_pretrained(
str(cache_dir),
trust_remote_code=True,
fix_mistral_regex=False,
)
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
proc.tokenizer.padding_side = "left"
return proc

@@ -1 +0,0 @@
../../../../docs/source/policy_molmoact2_README.md

@@ -1,21 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_molmoact2 import MolmoAct2Config
from .modeling_molmoact2 import MolmoAct2Policy
from .processor_molmoact2 import make_molmoact2_pre_post_processors
__all__ = ["MolmoAct2Config", "MolmoAct2Policy", "make_molmoact2_pre_post_processors"]

@@ -1,519 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import math
import os
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
from lerobot.optim import (
AdamWConfig,
CosineDecayWithWarmupSchedulerConfig,
LRSchedulerConfig,
OptimizerConfig,
)
from lerobot.utils.constants import ACTION, OBS_STATE
from ..rtc.configuration_rtc import RTCConfig
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
MOLMOACT2_TASK_TOKEN_BUDGET = 32
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_checkpoint_location(
checkpoint_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
checkpoint_path = str(checkpoint_path or "").strip()
if not checkpoint_path:
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
local_path = Path(checkpoint_path).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=checkpoint_path,
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
def _load_hf_norm_metadata_for_tag(
checkpoint_path: str,
*,
revision: str | None,
force_download: bool,
norm_tag: str | None,
) -> dict[str, Any]:
norm_tag = str(norm_tag or "").strip()
if not norm_tag:
return {}
checkpoint_location = Path(
_resolve_checkpoint_location(
checkpoint_path,
revision=revision,
force_download=force_download,
)
)
norm_stats_filename = "norm_stats.json"
config_path = checkpoint_location / "config.json"
if config_path.exists():
with suppress(OSError, json.JSONDecodeError):
norm_stats_filename = str(
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
)
stats_path = checkpoint_location / norm_stats_filename
if not stats_path.exists():
raise FileNotFoundError(
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
)
payload = json.loads(stats_path.read_text())
metadata_by_tag = payload.get("metadata_by_tag")
if not isinstance(metadata_by_tag, dict):
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
metadata = metadata_by_tag.get(norm_tag)
if not isinstance(metadata, dict):
available = sorted(str(tag) for tag in metadata_by_tag)
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
return metadata
@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup")
@dataclass
class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig):
"""MolmoAct2-local cosine scheduler with optional decay-step auto-match.
LeRobot's generic cosine scheduler keeps an explicit integer decay length.
For MolmoAct2, leaving num_decay_steps unset means "decay across this run's
training steps"; build() is the first point where num_training_steps is known.
"""
num_decay_steps: int | None
def build(self, optimizer, num_training_steps: int):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.peak_lr,
decay_lr=self.decay_lr,
num_warmup_steps=self.num_warmup_steps,
num_decay_steps=num_training_steps if self.num_decay_steps is None else self.num_decay_steps,
).build(optimizer, num_training_steps=num_training_steps)
def _round_up(value: int, multiple: int) -> int:
return int(math.ceil(value / multiple) * multiple)
def infer_molmoact2_max_sequence_length(
*,
num_images: int,
state_dim: int,
action_dim: int,
action_horizon: int,
include_discrete_action: bool,
) -> int:
"""Infer the padded text/image sequence cap from MolmoAct2's fixed token layout."""
if num_images < 1:
num_images = MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim < 0:
state_dim = 0
if action_dim < 1:
action_dim = 1
if action_horizon < 1:
action_horizon = 1
image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
prompt_tokens = (
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
+ MOLMOACT2_TASK_TOKEN_BUDGET
+ state_dim
+ MOLMOACT2_SEQUENCE_LENGTH_MARGIN
)
action_tokens = 0
if include_discrete_action:
action_tokens_per_step = max(
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM),
)
action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step
return _round_up(
image_tokens + prompt_tokens + action_tokens,
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE,
)
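As a worked example of the token budget computed above, the standalone sketch below repeats the same arithmetic with the constants copied from this file. The function name `sequence_budget` and the sample dimensions (2 cameras, 8 state values, 7 action values, horizon 30) are assumptions chosen for illustration.

```python
import math

# Constants copied from the MOLMOACT2_* definitions in this file.
IMAGE_TOKENS_PER_IMAGE = 196
FIXED_PROMPT_TOKEN_BUDGET = 80
TASK_TOKEN_BUDGET = 32
SEQUENCE_LENGTH_MARGIN = 32
SEQUENCE_LENGTH_MULTIPLE = 64
DISCRETE_ACTION_WRAPPER_TOKENS = 4
MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
DISCRETE_ACTION_TOKENS_PER_DIM = 0.95


def sequence_budget(num_images, state_dim, action_dim, action_horizon, include_discrete_action):
    """Sum image, prompt, and optional action tokens, then round up to a multiple of 64."""
    image_tokens = num_images * IMAGE_TOKENS_PER_IMAGE
    prompt_tokens = (
        FIXED_PROMPT_TOKEN_BUDGET + TASK_TOKEN_BUDGET + state_dim + SEQUENCE_LENGTH_MARGIN
    )
    action_tokens = 0
    if include_discrete_action:
        per_step = max(
            MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
            math.ceil(action_dim * DISCRETE_ACTION_TOKENS_PER_DIM),
        )
        action_tokens = DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * per_step
    total = image_tokens + prompt_tokens + action_tokens
    return math.ceil(total / SEQUENCE_LENGTH_MULTIPLE) * SEQUENCE_LENGTH_MULTIPLE


# 392 image + 152 prompt + (4 + 30 * 7) action = 758, rounded up to 768.
with_discrete = sequence_budget(2, 8, 7, 30, include_discrete_action=True)
# 392 image + 152 prompt = 544, rounded up to 576.
without_discrete = sequence_budget(2, 8, 7, 30, include_discrete_action=False)
```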
@PreTrainedConfig.register_subclass("molmoact2")
@dataclass
class MolmoAct2Config(PreTrainedConfig):
"""MolmoAct2 policy backed by the converted HF checkpoint implementation."""
checkpoint_path: str = "allenai/MolmoAct2"
checkpoint_revision: str | None = None
checkpoint_force_download: bool = False
n_obs_steps: int = 1
chunk_size: int = 30
n_action_steps: int = 30
action_mode: str = "both"
inference_action_mode: str | None = None
discrete_action_tokenizer: str = "allenai/MolmoAct2-FAST-Tokenizer"
discrete_generation_max_steps: int | None = None
norm_tag: str | None = None
setup_type: str = ""
control_mode: str = ""
image_keys: list[str] = field(default_factory=list)
normalize_language: bool = True
add_setup_tokens: bool = True
add_control_tokens: bool = True
normalize_gripper: bool = False
num_state_tokens: int = 256
# Leave unset for the default MolmoAct2 sequence budget inferred from the fixed
# image/prompt/state/action token layout. Override only for unusual long prompts.
max_sequence_length: int | None = None
# Fixed by released MolmoAct2 checkpoints. We validate this at model load.
expected_max_action_dim: int = 32
# Flow-matching training knobs copied from the original MolmoAct2 training path.
num_flow_timesteps: int = 8
flow_matching_cutoff: float = 1.0
flow_matching_time_offset: float = 0.001
flow_matching_time_scale: float = 0.999
flow_matching_beta_alpha: float = 1.0
flow_matching_beta_beta: float = 1.5
num_inference_steps: int | None = None
mask_action_dim_padding: bool = True
enable_inference_cuda_graph: bool = True
# MolmoAct2-local eval option. When enabled, stochastic continuous action
# generation uses a rollout-local generator derived from eval_seed.
per_episode_seed: bool = False
eval_seed: int | None = None
rtc_config: RTCConfig | None = None
# Default is full finetuning with gradients from the action expert flowing into the VLM.
enable_lora_vlm: bool = False
lora_rank: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_bias: str = "none"
enable_lora_action_expert: bool = False
enable_knowledge_insulation: bool = False
freeze_embedding: bool = True
train_action_expert_only: bool = False
gradient_checkpointing: bool = False
model_dtype: str = "bfloat16"
softmax_auxiliary_loss: bool = True
softmax_auxiliary_loss_scale: float = 1e-4
discrete_loss_token_weighting: str = "root_subsegments_root_tokens"
optimizer_lr: float = 1e-5
optimizer_vit_lr: float = 5e-6
optimizer_connector_lr: float = 5e-6
optimizer_action_expert_lr: float = 5e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-6
optimizer_weight_decay: float = 0.0
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 200
scheduler_decay_steps: int | None = None
scheduler_decay_lr: float = 1e-6
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.QUANTILES,
"ACTION": NormalizationMode.QUANTILES,
}
)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
dataset_feature_names: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
super().__post_init__()
if self.action_mode not in {"continuous", "discrete", "both"}:
raise ValueError(
f"Unsupported action_mode={self.action_mode!r}. "
"Expected one of {'continuous', 'discrete', 'both'}."
)
if self.inference_action_mode not in {None, "continuous", "discrete"}:
raise ValueError(
f"Unsupported inference_action_mode={self.inference_action_mode!r}. "
"Expected one of {None, 'continuous', 'discrete'}."
)
if self.inference_action_mode == "continuous" and self.action_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' cannot run continuous inference.")
if self.inference_action_mode == "discrete" and self.action_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' cannot run discrete inference.")
if self.train_action_expert_only and self.action_mode != "continuous":
raise ValueError("MolmoAct2 train_action_expert_only requires action_mode='continuous'.")
if self.train_action_expert_only and self.enable_lora_vlm:
raise ValueError("MolmoAct2 train_action_expert_only is incompatible with enable_lora_vlm.")
if self.enable_lora_action_expert and not self.enable_lora_vlm:
raise ValueError("MolmoAct2 enable_lora_action_expert requires enable_lora_vlm.")
if self.chunk_size < 1:
raise ValueError(f"chunk_size must be >= 1, got {self.chunk_size}.")
if self.n_action_steps < 1:
raise ValueError(f"n_action_steps must be >= 1, got {self.n_action_steps}.")
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})."
)
if self.expected_max_action_dim != 32:
raise ValueError("MolmoAct2 released checkpoints use expected_max_action_dim=32.")
if self.model_dtype not in {"float32", "bfloat16", "float16"}:
raise ValueError(
f"Unsupported model_dtype={self.model_dtype!r}. Expected 'float32', 'bfloat16', or 'float16'."
)
if self.lora_rank < 1:
raise ValueError(f"lora_rank must be >= 1, got {self.lora_rank}.")
if self.lora_alpha < 1:
raise ValueError(f"lora_alpha must be >= 1, got {self.lora_alpha}.")
if not 0 <= self.lora_dropout <= 1:
raise ValueError(f"lora_dropout must be in [0, 1], got {self.lora_dropout}.")
if self.lora_bias not in {"none", "all", "lora_only"}:
raise ValueError(
f"Unsupported lora_bias={self.lora_bias!r}. Expected one of 'none', 'all', or 'lora_only'."
)
if self.discrete_loss_token_weighting not in {
"none",
"token",
"root_tokens",
"root_subsegments",
"root_subsegments_root_tokens",
}:
raise ValueError(
f"Unsupported discrete_loss_token_weighting={self.discrete_loss_token_weighting!r}."
)
if self.discrete_generation_max_steps is not None and self.discrete_generation_max_steps < 1:
raise ValueError(
f"discrete_generation_max_steps must be >= 1 or None, got {self.discrete_generation_max_steps}."
)
if self.max_sequence_length is not None and self.max_sequence_length < 1:
raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.")
def inferred_max_sequence_length(
self,
*,
num_images: int | None = None,
state_dim: int | None = None,
action_dim: int | None = None,
action_horizon: int | None = None,
include_discrete_action: bool | None = None,
) -> int:
if self.max_sequence_length is not None:
return int(self.max_sequence_length)
if num_images is None:
num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim is None:
state_feature = self.robot_state_feature
state_dim = int(state_feature.shape[0]) if state_feature is not None else 0
if action_dim is None:
action_feature = self.action_feature
action_dim = (
int(action_feature.shape[0]) if action_feature is not None else self.expected_max_action_dim
)
if action_horizon is None:
action_horizon = self.chunk_size
if include_discrete_action is None:
include_discrete_action = self.action_mode in {"discrete", "both"}
return infer_molmoact2_max_sequence_length(
num_images=int(num_images),
state_dim=int(state_dim),
action_dim=int(action_dim),
action_horizon=int(action_horizon),
include_discrete_action=bool(include_discrete_action),
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
def get_optimizer_preset(self) -> OptimizerConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return MolmoAct2CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
def set_dataset_feature_metadata(self, features: dict[str, Any]) -> None:
self.dataset_feature_names = {}
for key in (ACTION, OBS_STATE):
feature = features.get(key) if isinstance(features, dict) else None
if isinstance(feature, dict) and feature.get("names") is not None:
self.dataset_feature_names[key] = feature["names"]
def validate_features(self) -> None:
"""Validate and set up MolmoAct2 input and output features."""
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"MolmoAct2 policy requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
if OBS_STATE not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(0,),
)
self.input_features[OBS_STATE] = state_feature
if ACTION not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.expected_max_action_dim,),
)
self.output_features[ACTION] = action_feature
def apply_norm_tag_metadata(self) -> None:
if not str(self.norm_tag or "").strip():
return
metadata = _load_hf_norm_metadata_for_tag(
self.checkpoint_path,
revision=self.checkpoint_revision,
force_download=bool(self.checkpoint_force_download),
norm_tag=self.norm_tag,
)
if metadata.get("action_horizon") is not None:
self.chunk_size = int(metadata["action_horizon"])
if metadata.get("n_action_steps") is not None:
self.n_action_steps = int(metadata["n_action_steps"])
if not self.setup_type and metadata.get("setup_type") is not None:
self.setup_type = str(metadata["setup_type"])
if not self.control_mode and metadata.get("control_mode") is not None:
self.control_mode = str(metadata["control_mode"])
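The override rules in `apply_norm_tag_metadata` can be illustrated on plain dictionaries: `action_horizon` and `n_action_steps` always replace the configured values when present, while `setup_type` and `control_mode` are only filled in when the config leaves them empty. The sketch below assumes a `norm_stats.json` payload with the `metadata_by_tag` layout that the loader checks for. The tag name `example_tag` and all values are invented for illustration.

```python
import json

# Hypothetical payload; only the "metadata_by_tag" key is taken from the loader above.
payload = json.loads(
    """
    {"metadata_by_tag": {"example_tag": {
        "action_horizon": 16,
        "n_action_steps": 8,
        "setup_type": "single arm",
        "control_mode": "ignored because already set"
    }}}
    """
)


def apply_metadata(cfg: dict, metadata: dict) -> dict:
    """Return a copy of cfg with the tag metadata applied using the rules described above."""
    out = dict(cfg)
    if metadata.get("action_horizon") is not None:
        out["chunk_size"] = int(metadata["action_horizon"])
    if metadata.get("n_action_steps") is not None:
        out["n_action_steps"] = int(metadata["n_action_steps"])
    for key in ("setup_type", "control_mode"):
        # Only fill fields the user left empty.
        if not out.get(key) and metadata.get(key) is not None:
            out[key] = str(metadata[key])
    return out


cfg = {"chunk_size": 30, "n_action_steps": 30, "setup_type": "", "control_mode": "joint"}
resolved = apply_metadata(cfg, payload["metadata_by_tag"]["example_tag"])
```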
def saved_policy_action_mode(self) -> str | None:
pretrained_path = getattr(self, "pretrained_path", None)
if pretrained_path is None:
return None
config_path = Path(pretrained_path) / "config.json"
if not config_path.exists():
return None
try:
mode = json.loads(config_path.read_text()).get("action_mode")
except (OSError, json.JSONDecodeError):
return None
if mode in {"continuous", "discrete", "both"}:
return str(mode)
return None
def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str:
return saved_policy_action_mode or self.action_mode
def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None:
requested_mode = self.inference_action_mode
if requested_mode is None:
return
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
"continuous inference."
)
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
)
def validate_checkpoint_action_mode(
self,
checkpoint_action_mode: str,
*,
has_action_expert: bool,
) -> None:
if self.action_mode == "both" and checkpoint_action_mode != "both":
raise ValueError(
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
)
if self.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
raise ValueError(
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
f"got {checkpoint_action_mode!r}."
)
if self.action_mode in {"continuous", "both"} and not has_action_expert:
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
def resolve_inference_action_mode(
self,
requested_mode: str | None,
saved_policy_action_mode: str | None = None,
) -> str:
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode is None:
requested_mode = self.inference_action_mode
if requested_mode is None:
raise ValueError(
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
"to either 'continuous' or 'discrete'."
)
if requested_mode not in {"continuous", "discrete"}:
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
return requested_mode
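
The compatibility rules enforced by `resolve_inference_action_mode` can be sketched as a standalone function. This is only an illustration under assumptions: `resolve_inference_mode` and its three plain-string arguments are hypothetical stand-ins for the method and the config state above, not part of the removed code.

```python
def resolve_inference_mode(requested, configured, training_mode):
    """Hypothetical standalone version of the checks shown above.

    requested:     mode asked for at call time, or None
    configured:    fallback mode stored on the config, or None
    training_mode: mode the checkpoint was trained with
                   ("continuous", "discrete" or "both")
    """
    # Fall back to the configured mode when the caller gives none.
    mode = requested if requested is not None else configured
    if mode is None:
        raise ValueError("inference mode must be set to 'continuous' or 'discrete'")
    if mode not in {"continuous", "discrete"}:
        raise ValueError(f"unsupported inference mode: {mode!r}")
    # A single-mode checkpoint can only serve its own mode;
    # a checkpoint trained with 'both' can serve either.
    if mode == "continuous" and training_mode == "discrete":
        raise ValueError("discrete checkpoint cannot run continuous inference")
    if mode == "discrete" and training_mode == "continuous":
        raise ValueError("continuous checkpoint cannot run discrete inference")
    return mode
```

In this sketch a checkpoint trained with `"both"` accepts either inference mode, while a single-mode checkpoint rejects the other mode with a `ValueError`.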

View File

@@ -1,17 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa

View File

@@ -1,237 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
import logging
import os
from pathlib import Path
from typing import ClassVar
import numpy as np
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_tokenizer_location(
tokenizer_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
local_path = Path(str(tokenizer_path)).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=str(tokenizer_path),
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
class UniversalActionProcessor(ProcessorMixin):
attributes: ClassVar[list[str]] = ["tokenizer"]
tokenizer_class: str = "AutoTokenizer"
def __init__(
self,
tokenizer: PreTrainedTokenizerFast,
scale: float = 10,
vocab_size: int = 1024,
min_token: int = 0,
*,
action_dim: int | None = None,
time_horizon: int | None = None,
):
self.scale = scale
self.vocab_size = vocab_size
self.min_token = min_token
# Action horizon and dimension needed during decoding. These can be specified
# in three ways (in order of priority):
# 1. passed in as kwargs to decode()
# 2. in the constructor
# 3. cached from the last time decode() was called
self.time_horizon = time_horizon
self.action_dim = action_dim
self.called_time_horizon = time_horizon
self.called_action_dim = action_dim
super().__init__(tokenizer)
self.bpe_tokenizer = self.tokenizer
def __call__(self, action_chunk: np.ndarray) -> list[list[int]]:
from scipy.fft import dct
assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
if action_chunk.ndim == 2:
action_chunk = action_chunk[None, ...]
# Cache the time horizon and action dimension for decoding
self.called_time_horizon = action_chunk.shape[-2]
self.called_action_dim = action_chunk.shape[-1]
dct_coeff = dct(action_chunk, axis=1, norm="ortho")
dct_coeff = np.around(dct_coeff * self.scale)
tokens = []
for elem in dct_coeff:
token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
return tokens
def decode(
self,
tokens: list[list[int]],
*,
time_horizon: int | None = None,
action_dim: int | None = None,
) -> np.array:
from scipy.fft import idct
self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
self.action_dim = action_dim or self.action_dim or self.called_action_dim
# Cache the time horizon and action dimension for the next call
self.called_time_horizon = self.time_horizon
self.called_action_dim = self.action_dim
assert self.time_horizon is not None and self.action_dim is not None, (
"Tokenizer not initialized, call the processor on an action chunk once or pass in time_horizon and action_dim."
)
decoded_actions = []
for token in tokens:
try:
decoded_tokens = self.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
assert decoded_dct_coeff.shape == (
self.time_horizon,
self.action_dim,
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
)
except Exception as e:
print(f"Error decoding tokens: {e}")
print(f"Tokens: {token}")
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
@classmethod
def fit(
cls,
action_data: list[np.array],
scale: float = 10,
vocab_size: int = 1024,
*,
time_horizon: int | None = None,
action_dim: int | None = None,
) -> "UniversalActionProcessor":
from scipy.fft import dct
# Run DCT over all inputs
dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]
# Quantize and find min token
max_token = int(np.around(np.concatenate(dct_tokens) * scale).max())
min_token = int(np.around(np.concatenate(dct_tokens) * scale).min())
min_vocab_size = max_token - min_token
assert min_vocab_size <= vocab_size, (
f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
)
if min_vocab_size + 100 > vocab_size:
logging.warning(
f"Initial alphabet size {min_vocab_size} is almost as large as the vocab "
f"size {vocab_size}, consider increasing vocab size"
)
# Make token iterator for BPE training
def _token_iter():
for tokens in dct_tokens:
rounded_tokens = np.around(tokens * scale) - min_token
rounded_tokens = rounded_tokens.astype(int)
string = "".join(map(chr, rounded_tokens))
yield string
# Train BPE tokenizer
bpe = ByteLevelBPETokenizer()
# Set up the entire range of possible tokens as the initial alphabet
alphabet = [chr(i) for i in range(max_token - min_token + 1)]
trainer = BpeTrainer(
vocab_size=vocab_size,
min_frequency=2,
show_progress=True,
special_tokens=[],
initial_alphabet=alphabet,
max_token_length=10000,
)
# Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
# because it doesn't support custom alphabets)
bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)
return cls(
PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
scale=scale,
vocab_size=vocab_size,
min_token=min_token,
time_horizon=time_horizon,
action_dim=action_dim,
)
@classmethod
def from_pretrained_local(
cls,
pretrained_model_name_or_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> "UniversalActionProcessor":
location = Path(
_resolve_tokenizer_location(
pretrained_model_name_or_path,
revision=revision,
force_download=force_download,
)
)
processor_config = {}
processor_config_path = location / "processor_config.json"
if processor_config_path.exists():
import json
processor_config = json.loads(processor_config_path.read_text())
tokenizer = PreTrainedTokenizerFast.from_pretrained(str(location))
return cls(
tokenizer,
scale=processor_config.get("scale", 10),
vocab_size=processor_config.get("vocab_size", 1024),
min_token=processor_config.get("min_token", 0),
action_dim=processor_config.get("action_dim"),
time_horizon=processor_config.get("time_horizon"),
)
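
The DCT, quantize and character-mapping steps that `UniversalActionProcessor` applies before BPE can be sketched without any dependencies. This is a minimal illustration under assumptions: the helpers `dct_ortho`, `idct_ortho`, `encode_chunk` and `decode_chunk` are hypothetical, they work on plain lists, and they skip the BPE stage entirely. The hand-written transform is intended to mirror the orthonormal DCT-II that `scipy.fft.dct(..., norm="ortho")` computes.

```python
import math


def dct_ortho(x):
    """Orthonormal DCT-II of a 1-D sequence."""
    n = len(x)
    out = []
    for k in range(n):
        s = sum(x[t] * math.cos(math.pi * (2 * t + 1) * k / (2 * n)) for t in range(n))
        out.append(s * math.sqrt((1 if k == 0 else 2) / n))
    return out


def idct_ortho(c):
    """Inverse of dct_ortho (orthonormal DCT-III)."""
    n = len(c)
    return [
        sum(
            c[k] * math.sqrt((1 if k == 0 else 2) / n)
            * math.cos(math.pi * (2 * t + 1) * k / (2 * n))
            for k in range(n)
        )
        for t in range(n)
    ]


def encode_chunk(chunk, scale, min_token):
    """Map a [T][D] action chunk to a string of T * D characters."""
    horizon, dim = len(chunk), len(chunk[0])
    # DCT along the time axis, one transform per action dimension.
    cols = [dct_ortho([chunk[t][d] for t in range(horizon)]) for d in range(dim)]
    # Quantize, shift so the smallest symbol is >= 0, flatten row-major.
    symbols = [
        max(round(cols[d][t] * scale) - min_token, 0)
        for t in range(horizon)
        for d in range(dim)
    ]
    return "".join(map(chr, symbols))


def decode_chunk(text, scale, min_token, action_dim):
    """Invert encode_chunk up to quantization error."""
    values = [ord(ch) + min_token for ch in text]
    horizon = len(values) // action_dim
    cols = [
        idct_ortho([values[t * action_dim + d] / scale for t in range(horizon)])
        for d in range(action_dim)
    ]
    return [[cols[d][t] for d in range(action_dim)] for t in range(horizon)]


chunk = [[0.0, 1.0], [0.1, 0.9], [0.2, 0.8], [0.3, 0.7]]
text = encode_chunk(chunk, scale=10, min_token=-50)
recon = decode_chunk(text, scale=10, min_token=-50, action_dim=2)
```

Because the transform is orthonormal, each reconstructed value differs from the input by at most `0.5 / scale * sqrt(T)`, provided no symbol was clamped to zero. A larger `scale` lowers that error but widens the range of symbols the BPE alphabet has to cover.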

View File

@@ -1,553 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""
MolmoAct2 configuration
"""
from typing import Optional, Any
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class MolmoAct2VitConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`].
It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments,
defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```python
>>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer
>>> # Initializing a MolmoAct2VitConfig
>>> configuration = MolmoAct2VitConfig()
>>> # Initializing a MolmoAct2VisionTransformer (with random weights)
>>> model = MolmoAct2VisionTransformer(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "molmoact2"
base_config_key = "vit_config"
def __init__(
self,
hidden_size: int = 1152,
intermediate_size: int = 4304,
num_hidden_layers: int = 27,
num_attention_heads: int = 16,
num_key_value_heads: int = 16,
head_dim: int = 72,
hidden_act: str = "gelu_pytorch_tanh",
layer_norm_eps: float = 1e-6,
image_default_input_size: tuple[int, int] = (378, 378),
image_patch_size: int = 14,
image_num_pos: int = 577,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
initializer_range: float = 0.02,
float32_attention: bool = True,
attn_implementation: str = "eager",
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(attn_implementation=attn_implementation, **kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.layer_norm_eps = layer_norm_eps
self.image_default_input_size = image_default_input_size
self.image_patch_size = image_patch_size
self.image_num_pos = image_num_pos
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.initializer_range = initializer_range
self.float32_attention = float32_attention
@property
def image_num_patch(self):
h, w = self.image_default_input_size
return h // self.image_patch_size, w // self.image_patch_size
class MolmoAct2AdapterConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of MolmoAct2Adapter. Together with MolmoAct2VitConfig,
it is used to instantiate a MolmoAct2VisionBackbone according to the specified arguments,
defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```python
>>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone
>>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
>>> vit_config = MolmoAct2VitConfig()
>>> adapter_config = MolmoAct2AdapterConfig()
>>> # Initializing a MolmoAct2VisionBackbone (with random weights)
>>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)
>>> # Accessing the model configuration
>>> vit_configuration = model.vit_config
>>> adapter_configuration = model.adapter_config
```"""
model_type = "molmoact2"
base_config_key = "adapter_config"
def __init__(
self,
vit_layers: tuple = (-3, -9),
pooling_attention_mask: bool = False,
hidden_size: int = 1152,
num_attention_heads: int = 16,
num_key_value_heads: int = 16,
head_dim: int = 72,
float32_attention: bool = True,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
hidden_act: str = "silu",
intermediate_size: int = 18944,
text_hidden_size: int = 3584,
image_feature_dropout: float = 0.0,
initializer_range: float = 0.02,
attn_implementation: str = "eager",
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(attn_implementation=attn_implementation, **kwargs)
self.vit_layers = vit_layers
self.pooling_attention_mask = pooling_attention_mask
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.float32_attention = float32_attention
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.text_hidden_size = text_hidden_size
self.image_feature_dropout = image_feature_dropout
self.initializer_range = initializer_range
class MolmoAct2TextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
`MolmoAct2TextModel` according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```python
>>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel
>>> # Initializing a MolmoAct2TextConfig
>>> configuration = MolmoAct2TextConfig()
>>> # Initializing a MolmoAct2TextModel (with random weights)
>>> model = MolmoAct2TextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "molmoact2_text"
base_config_key = "text_config"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"blocks.*.self_attn.att_proj": "colwise",
"blocks.*.self_attn.attn_out": "rowwise",
"blocks.*.mlp.ff_proj": "colwise",
"blocks.*.mlp.ff_out": "rowwise",
}
base_model_pp_plan = {
"wte": (["input_ids"], ["inputs_embeds"]),
"blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
"ln_f": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
hidden_size: int = 3584,
num_attention_heads: int = 28,
num_key_value_heads: int | None = 4,
head_dim: int = 128,
vocab_size: int = 152064,
additional_vocab_size: int = 128,
qkv_bias: bool = True,
num_hidden_layers: int = 48,
intermediate_size: int = 18944,
hidden_act: str = "silu",
embedding_dropout: float = 0.0,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
max_position_embeddings: int = 4096,
rope_theta: float = 1000000.0,
rope_scaling: dict[str, Any] | None = None,
rope_scaling_layers: list[int] | None = None,
use_qk_norm: bool = False,
qk_norm_type: str = "olmo",
layer_norm_eps: float = 1e-6,
norm_after: bool = False,
initializer_range: float = 0.02,
use_cache=True,
tie_word_embeddings=False,
attn_implementation: str = "eager",
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(
tie_word_embeddings=tie_word_embeddings, attn_implementation=attn_implementation, **kwargs
)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.vocab_size = vocab_size
self.additional_vocab_size = additional_vocab_size
self.qkv_bias = qkv_bias
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.embedding_dropout = embedding_dropout
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.rope_scaling_layers = rope_scaling_layers
self.use_qk_norm = use_qk_norm
self.qk_norm_type = qk_norm_type
self.layer_norm_eps = layer_norm_eps
self.norm_after = norm_after
self.initializer_range = initializer_range
self.use_cache = use_cache
# Validate the correctness of rotary position embeddings parameters
rope_config_validation(self)
class MolmoAct2ActionExpertConfig(PretrainedConfig):
r"""Configuration for the MolmoAct2 modern action expert."""
model_type = "molmoact2_action_expert"
base_config_key = "action_expert_config"
def __init__(
self,
max_action_horizon: int = 32,
max_action_dim: int = 32,
hidden_size: int = 1024,
num_layers: int = 32,
num_heads: int = 16,
mlp_ratio: float = 8.0 / 3.0,
ffn_multiple_of: int = 256,
timestep_embed_dim: int = 256,
dropout: float = 0.0,
attn_dropout: float = 0.0,
context_layer_norm: bool = True,
qk_norm: bool = True,
qk_norm_eps: float = 1e-6,
rope: bool = True,
causal_attn: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.max_action_horizon = max_action_horizon
self.max_action_dim = max_action_dim
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
self.ffn_multiple_of = ffn_multiple_of
self.timestep_embed_dim = timestep_embed_dim
self.dropout = dropout
self.attn_dropout = attn_dropout
self.context_layer_norm = context_layer_norm
self.qk_norm = qk_norm
self.qk_norm_eps = qk_norm_eps
self.rope = rope
self.causal_attn = causal_attn
def to_dict(self):
output = super().to_dict()
# These are derived from the parent MolmoAct2Config for HF exports. Keeping
# them out of the public nested config avoids duplicated sources of truth.
output.pop("max_action_horizon", None)
output.pop("max_action_dim", None)
return output
class MolmoAct2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
It is used to instantiate a MolmoAct2 model according to the specified arguments, defining the model architecture.
Example:
```python
>>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig, MolmoAct2ForConditionalGeneration
>>> # Initializing a MolmoAct2VitConfig
>>> vit_config = MolmoAct2VitConfig()
>>> # Initializing a MolmoAct2AdapterConfig
>>> adapter_config = MolmoAct2AdapterConfig()
>>> # Initializing a MolmoAct2TextConfig
>>> text_config = MolmoAct2TextConfig()
>>> # Initializing a MolmoAct2Config
>>> configuration = MolmoAct2Config(
... vit_config=vit_config,
... adapter_config=adapter_config,
... text_config=text_config,
... image_start_token_id=151936,
... image_end_token_id=151937,
... image_patch_id=151938,
... image_col_id=151939,
... low_res_image_start_token_id=151940,
... image_low_res_id=151942,
... frame_start_token_id=151943,
... frame_end_token_id=151944,
... )
>>> # Initializing a model
>>> model = MolmoAct2ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "molmoact2"
sub_configs = {
"text_config": MolmoAct2TextConfig,
"vit_config": MolmoAct2VitConfig,
"adapter_config": MolmoAct2AdapterConfig,
"action_expert_config": MolmoAct2ActionExpertConfig,
}
def __init__(
self,
vit_config: MolmoAct2VitConfig = None,
adapter_config: MolmoAct2AdapterConfig = None,
text_config: MolmoAct2TextConfig = None,
action_expert_config: MolmoAct2ActionExpertConfig = None,
image_start_token_id: int = None,
low_res_image_start_token_id: int = None,
image_end_token_id: int = None,
image_low_res_id: int = None,
image_patch_id: int = None,
image_col_id: int = None,
frame_start_token_id: int = None,
frame_end_token_id: int = None,
use_frame_special_tokens: bool = True,
initializer_range: float = 0.02,
add_action_expert: bool = True,
max_action_dim: int = 32,
max_action_horizon: int = 30,
n_obs_steps: int = 30,
action_mode: str = "both",
state_format: str = "discrete",
flow_matching_num_steps: int = 10,
flow_matching_cutoff: float = 1.0,
flow_matching_time_offset: float = 0.001,
flow_matching_time_scale: float = 0.999,
flow_matching_beta_alpha: float = 1.0,
flow_matching_beta_beta: float = 1.5,
mask_action_dim_padding: bool = True,
enable_depth_reasoning: bool = False,
depth_mode: int = 2,
num_depth_codes: int = 100,
action_expert_depth_gate: bool = False,
action_expert_depth_gate_per_layer: bool = False,
action_expert_depth_gate_init_bias: float = -4.0,
action_output_token_id: int = None,
action_start_token_id: int = None,
action_end_token_id: int = None,
action_token_start_id: int = None,
num_action_tokens: int = 0,
depth_output_token_id: int = None,
depth_start_token_id: int = None,
depth_end_token_id: int = None,
depth_token_start_id: int = None,
num_depth_tokens: int = 0,
state_start_token_id: int = None,
state_end_token_id: int = None,
state_token_start_id: int = None,
num_state_tokens: int = 0,
add_setup_tokens: bool = True,
add_control_tokens: bool = True,
norm_stats_filename: str = "norm_stats.json",
**kwargs,
):
super().__init__(**kwargs)
if vit_config is None:
self.vit_config = MolmoAct2VitConfig()
elif isinstance(vit_config, dict):
self.vit_config = MolmoAct2VitConfig(**vit_config)
else:
self.vit_config = vit_config
if adapter_config is None:
self.adapter_config = MolmoAct2AdapterConfig()
elif isinstance(adapter_config, dict):
self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
else:
self.adapter_config = adapter_config
if text_config is None:
self.text_config = MolmoAct2TextConfig()
elif isinstance(text_config, dict):
self.text_config = MolmoAct2TextConfig(**text_config)
else:
self.text_config = text_config
self.add_action_expert = bool(add_action_expert)
if not self.add_action_expert:
self.action_expert_config = None
elif action_expert_config is None:
self.action_expert_config = MolmoAct2ActionExpertConfig(
max_action_horizon=max_action_horizon,
max_action_dim=max_action_dim,
num_layers=self.text_config.num_hidden_layers,
)
elif isinstance(action_expert_config, dict):
self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
else:
self.action_expert_config = action_expert_config
if self.add_action_expert:
self.action_expert_config.max_action_dim = int(max_action_dim)
self.action_expert_config.max_action_horizon = int(max_action_horizon)
self._validate_release_action_config(
state_format=state_format,
)
self.image_start_token_id = image_start_token_id
self.low_res_image_start_token_id = low_res_image_start_token_id
self.image_end_token_id = image_end_token_id
self.image_low_res_id = image_low_res_id
self.image_high_res_id = image_patch_id
self.image_patch_id = image_patch_id
self.image_col_id = image_col_id
self.frame_start_token_id = frame_start_token_id
self.frame_end_token_id = frame_end_token_id
self.use_frame_special_tokens = use_frame_special_tokens
self.initializer_range = initializer_range
self.max_action_dim = max_action_dim
self.max_action_horizon = max_action_horizon
self.n_obs_steps = n_obs_steps
self.action_mode = action_mode
self.state_format = state_format
self.flow_matching_num_steps = flow_matching_num_steps
self.flow_matching_cutoff = flow_matching_cutoff
self.flow_matching_time_offset = flow_matching_time_offset
self.flow_matching_time_scale = flow_matching_time_scale
self.flow_matching_beta_alpha = flow_matching_beta_alpha
self.flow_matching_beta_beta = flow_matching_beta_beta
self.mask_action_dim_padding = mask_action_dim_padding
self.enable_depth_reasoning = enable_depth_reasoning
self.depth_mode = depth_mode
self.num_depth_codes = num_depth_codes
self.action_expert_depth_gate = action_expert_depth_gate
self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
self.action_output_token_id = action_output_token_id
self.action_start_token_id = action_start_token_id
self.action_end_token_id = action_end_token_id
self.action_token_start_id = action_token_start_id
self.num_action_tokens = num_action_tokens
self.depth_output_token_id = depth_output_token_id
self.depth_start_token_id = depth_start_token_id
self.depth_end_token_id = depth_end_token_id
self.depth_token_start_id = depth_token_start_id
self.num_depth_tokens = num_depth_tokens
self.state_start_token_id = state_start_token_id
self.state_end_token_id = state_end_token_id
self.state_token_start_id = state_token_start_id
self.num_state_tokens = num_state_tokens
self.add_setup_tokens = add_setup_tokens
self.add_control_tokens = add_control_tokens
self.norm_stats_filename = norm_stats_filename
@staticmethod
def _validate_release_action_config(
*,
state_format: str,
) -> None:
if state_format != "discrete":
raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")
@property
def image_num_patch(self):
assert self.vit_config is not None
return self.vit_config.image_num_patch
@property
def num_attention_heads(self):
return self.text_config.num_attention_heads
@property
def num_key_value_heads(self):
return self.text_config.num_key_value_heads
@property
def head_dim(self):
return self.text_config.head_dim
@property
def num_hidden_layers(self):
return self.text_config.num_hidden_layers
@property
def hidden_size(self):
return self.text_config.hidden_size
@property
def vocab_size(self):
return self.text_config.vocab_size
@property
def max_position_embeddings(self):
return self.text_config.max_position_embeddings
MolmoAct2VitConfig.register_for_auto_class()
MolmoAct2AdapterConfig.register_for_auto_class()
MolmoAct2TextConfig.register_for_auto_class()
MolmoAct2ActionExpertConfig.register_for_auto_class()
MolmoAct2Config.register_for_auto_class()
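
`MolmoAct2Config.__init__` accepts each sub-config as `None`, a `dict`, or an already-built config object. That coercion pattern can be sketched in isolation. This is an illustration under assumptions: `ToyTextConfig` and `coerce_sub_config` are hypothetical names, and a plain dataclass stands in for `PretrainedConfig`. The two default values are the ones `MolmoAct2TextConfig` uses above.

```python
from dataclasses import dataclass


@dataclass
class ToyTextConfig:
    # Hypothetical stand-in for MolmoAct2TextConfig with two of its defaults.
    hidden_size: int = 3584
    num_hidden_layers: int = 48


def coerce_sub_config(value, config_cls):
    """Return a config_cls instance from None, a dict, or an existing instance."""
    if value is None:
        return config_cls()  # fall back to the class defaults
    if isinstance(value, dict):
        return config_cls(**value)  # e.g. loaded from a JSON config file
    return value  # already a config object, use it as-is


default_cfg = coerce_sub_config(None, ToyTextConfig)
from_dict = coerce_sub_config({"hidden_size": 1024}, ToyTextConfig)
existing = ToyTextConfig(num_hidden_layers=2)
passthrough = coerce_sub_config(existing, ToyTextConfig)
```

Accepting a `dict` matters mostly for round-tripping: a nested config serialized to JSON comes back as plain dictionaries, which the parent config then has to rebuild into typed objects.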

View File

@@ -1,564 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Image processor class for MolmoAct2"""
from typing import Optional, Union
import numpy as np
import einops
import torch
import torchvision.transforms
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
make_flat_list_of_images,
valid_images,
to_numpy_array,
)
from transformers.image_transforms import convert_to_rgb
from transformers.processing_utils import ImagesKwargs
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import TensorType, logging
logger = logging.get_logger(__name__)
def normalize_image(
image: np.ndarray,
image_mean: list[float],
image_std: list[float],
) -> np.ndarray:
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
image /= np.array(image_std, dtype=np.float32)[None, None, :]
return image
def resize_image(
image: np.ndarray,
desired_output_size: list[int],
resample: PILImageResampling,
) -> np.ndarray:
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
dtype = image.dtype
if torch.is_floating_point(image):
in_min = 0.0
in_max = 1.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
)
in_min = 0.0
in_max = 255.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0, 255).to(dtype)
resized = resized.to(torch.float32)
resized = (resized - in_min) / (in_max - in_min)
resized = torch.permute(resized, [1, 2, 0]).numpy()
return resized
def select_tiling(h, w, patch_size, max_num_crops):
"""Divide an image of size [h, w] into up to max_num_crops tiles of size patch_size"""
original_size = np.stack([h, w]) # [1, 2]
original_res = h * w
tilings = []
for i in range(1, max_num_crops + 1):
for j in range(1, max_num_crops + 1):
if i * j <= max_num_crops:
tilings.append((i, j))
# sort so argmin and argmax favour smaller tilings in the event of a tie
tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
# How much we would need to scale the image to fit exactly in each tiling
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
# The original size can be zero in rare cases if the image is smaller than the margin
# In those cases letting the scale become infinite means the tiling is based on the
# other side, or falls back to the smallest tiling
with np.errstate(divide="ignore"):
required_scale_d = candidate_resolutions.astype(np.float32) / original_size
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
if np.all(required_scale < 1):
# We are forced to downscale, so try to minimize the amount of downscaling
ix = np.argmax(required_scale)
else:
# Pick the resolution that required the least upscaling so that it most closely fits the image
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
ix = np.argmin(required_scale)
return candidate_tilings[ix]
def build_resized_image(
image: np.ndarray,
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
resized = resize_image(
image,
base_image_input_size,
resample,
)
resized = normalize_image(resized, image_mean, image_std)
if len(resized.shape) == 3:
resized = np.expand_dims(resized, 0)
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
return resized, resize_idx
def build_overlapping_crops(
image: np.ndarray,
max_crops: int,
overlap_margins: list[int],
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
"""Decompose an image into a set of overlapping crops
:return crop_arr: [n_crops, h, w, 3] The crops
:return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
the crops were extracted from, what patch in `crop_arr` it corresponds to
"""
original_image_h, original_image_w = image.shape[:2]
crop_size = base_image_input_size[0]
assert base_image_input_size[0] == base_image_input_size[1]
left_margin, right_margin = overlap_margins
total_margin_pixels = image_patch_size * (right_margin + left_margin) # pixels removed per dim
crop_patches = base_image_input_size[0] // image_patch_size # patches per crop dim
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
crop_window_size = crop_window_patches * image_patch_size
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
# Decide how to tile the image. To account for the overlap margins, we compute the tiling
# as if we had an image without the margins and were using a crop size without the margins
tiling = select_tiling(
original_image_h - total_margin_pixels,
original_image_w - total_margin_pixels,
crop_window_size,
max_crops,
)
src = resize_image(
image,
[
tiling[0] * crop_window_size + total_margin_pixels,
tiling[1] * crop_window_size + total_margin_pixels,
],
resample,
)
src = normalize_image(src, image_mean, image_std)
# Now we have to split the image into crops, and track what patches came from
# where in `patch_idx_arr`
n_crops = tiling[0] * tiling[1]
crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
on_crop = 0
for i in range(tiling[0]):
# Slide over `src` in steps of `crop_window_size`, but extract crops of size `crop_size`,
# which results in overlapping crop windows
y0 = i * crop_window_size
for j in range(tiling[1]):
x0 = j * crop_window_size
crop_arr[on_crop] = src[y0 : y0 + crop_size, x0 : x0 + crop_size]
patch_idx = np.arange(crop_patch_w * crop_patch_h).reshape(crop_patch_h, crop_patch_w)
patch_idx += on_crop * crop_patch_h * crop_patch_w
# Mask out idx that are in the overlap region
if i != 0:
patch_idx[:left_margin, :] = -1
if j != 0:
patch_idx[:, :left_margin] = -1
if i != tiling[0] - 1:
patch_idx[-right_margin:, :] = -1
if j != tiling[1] - 1:
patch_idx[:, -right_margin:] = -1
patch_idx_arr[on_crop] = patch_idx
on_crop += 1
# `patch_idx_arr` is ordered crop-by-crop; here we transpose `patch_idx_arr`
# so it is in row-major order (left-to-right, then top-to-bottom) over the whole image
patch_idx_arr = np.reshape(patch_idx_arr, [tiling[0], tiling[1], crop_patch_h, crop_patch_w])
patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
patch_idx_arr = np.reshape(patch_idx_arr, [-1])
# Now get the parts not in the overlap region, so it should map each patch in `src`
# to the correct patch it should come from in `crop_arr`
patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
src.shape[0] // image_patch_size,
src.shape[1] // image_patch_size,
)
return crop_arr, patch_idx_arr
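# Illustrative sketch (not part of the original file): the geometry used by
# `build_overlapping_crops` can be checked with plain arithmetic. The helper name
# `overlap_crop_layout` and the (2, 3) tiling are hypothetical; the crop size 378,
# patch size 14 and margins (4, 4) are the defaults documented for the image processor.

```python
def overlap_crop_layout(crop_size, patch_size, margins, tiling):
    """Return the resized source size and the top-left corner of every crop."""
    left, right = margins
    # Pixels that belong to the overlap margins along one dimension.
    total_margin_pixels = patch_size * (left + right)
    crop_patches = crop_size // patch_size
    # Each crop contributes only its inner window to the final image.
    crop_window_size = (crop_patches - (left + right)) * patch_size
    src_h = tiling[0] * crop_window_size + total_margin_pixels
    src_w = tiling[1] * crop_window_size + total_margin_pixels
    # Crops step by the window size but span `crop_size`, so neighbours overlap.
    origins = [
        (i * crop_window_size, j * crop_window_size)
        for i in range(tiling[0])
        for j in range(tiling[1])
    ]
    return (src_h, src_w), origins


src_size, origins = overlap_crop_layout(378, 14, (4, 4), (2, 3))
print(src_size)     # (644, 910)
print(origins[-1])  # (266, 532)
```

# The last crop starts at (266, 532) and spans 378 pixels, so it ends exactly at the
# (644, 910) border of the resized source image.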
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
if len(array.shape) == 3:
n_crops, h, w = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
array = np.transpose(array, [0, 1, 3, 2, 4])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
return array
else:
n_crops, h, w, c = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
return array
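# Illustrative sketch (not part of the original file): the reshape/transpose/reshape
# sequence above turns each image into a row-major list of flattened patches. The
# helper name `pixels_to_patches` and the tiny 4x4 single-channel input are assumptions
# chosen so the result can be read by eye.

```python
import numpy as np


def pixels_to_patches(images: np.ndarray, patch_size: int) -> np.ndarray:
    """[n, h, w, c] -> [n, n_patches, patch_size * patch_size * c]."""
    n, h, w, c = images.shape
    hp, wp = h // patch_size, w // patch_size
    # Split both spatial axes into (patch index, offset inside the patch).
    x = images.reshape(n, hp, patch_size, wp, patch_size, c)
    # Bring the two patch indices together, then the two in-patch offsets.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(n, hp * wp, patch_size * patch_size * c)


images = np.arange(16).reshape(1, 4, 4, 1)
patches = pixels_to_patches(images, 2)
print(patches.shape)          # (1, 4, 4)
print(patches[0, 0].tolist())  # [0, 1, 4, 5]
```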
def arange_for_pooling(
idx_arr: np.ndarray,
pool_h: int,
pool_w: int,
) -> np.ndarray:
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
idx_arr = np.pad(
idx_arr,
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
mode="constant",
constant_values=-1,
)
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
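# Illustrative sketch (not part of the original file): a NumPy-only equivalent of the
# padding and grouping done by `arange_for_pooling`, written without einops. The name
# `group_for_pooling` and the 3x3 example grid are assumptions for illustration.

```python
import numpy as np


def group_for_pooling(idx: np.ndarray, pool_h: int, pool_w: int) -> np.ndarray:
    """Pad a 2D index grid with -1 and group it into pooling windows."""
    # Amount needed to reach the next multiple of the pooling size.
    h_pad = -idx.shape[0] % pool_h
    w_pad = -idx.shape[1] % pool_w
    idx = np.pad(
        idx,
        [[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
        mode="constant",
        constant_values=-1,
    )
    h, w = idx.shape[0] // pool_h, idx.shape[1] // pool_w
    # Same as einops "(h dh) (w dw) -> h w (dh dw)".
    idx = idx.reshape(h, pool_h, w, pool_w).transpose(0, 2, 1, 3)
    return idx.reshape(h, w, pool_h * pool_w)


grid = np.arange(9).reshape(3, 3)
pooled = group_for_pooling(grid, 2, 2)
print(pooled.shape)          # (2, 2, 4)
print(pooled[0, 1].tolist())  # [2, -1, 5, -1]
```

# Entries equal to -1 mark padded positions that a pooling layer should ignore.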
def image_to_patches_and_grids(
image: np.ndarray,
max_crops: int,
overlap_margins: list[int],
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
image_pooling_w: int,
image_pooling_h: int,
crop_mode: str = "overlap-and-resize-c2",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
:return image_grids, the shape of each (low-res, high-res) image after pooling
:return crops, the image crops to process with the ViT
:return pooled_patch_idx, for each patch_id token in `image_tokens`, the indices of the
patches in `crops` to pool for that token, masked with -1
"""
if isinstance(base_image_input_size, int):
base_image_input_size = (base_image_input_size, base_image_input_size)
base_image_input_d = image_patch_size
pooling_w = image_pooling_w
pooling_h = image_pooling_h
crop_patch_w = base_image_input_size[1] // base_image_input_d
crop_patch_h = base_image_input_size[0] // base_image_input_d
if crop_mode == "resize":
resized, resize_idx = build_resized_image(
image,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
resized_h, resized_w = resize_idx.shape[:2]
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
image_grid = [np.array([resized_h, resized_w, 0, 0])]
return (
np.stack(image_grid, 0),
batch_pixels_to_patches(resized, image_patch_size),
resize_idx,
)
if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")
crop_arr, patch_idx_arr = build_overlapping_crops(
image,
max_crops,
overlap_margins,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
h, w = pooling_idx.shape[:2]
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
# Finally do the same for the global image
resized, resize_idx = build_resized_image(
image,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
crop_arr = np.concatenate([resized, crop_arr], 0)
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
resized_h, resized_w = resize_idx.shape[:2]
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
# The global image goes first, so the patch indices of the crops are offset by the number of patches in it
pooling_idx = np.where(pooling_idx >= 0, pooling_idx + crop_patch_h * crop_patch_w, -1)
pooling_idx = np.concatenate([resize_idx, pooling_idx])
image_grid = [np.array([resized_h, resized_w, h, w])]
return (np.stack(image_grid, 0), batch_pixels_to_patches(crop_arr, image_patch_size), pooling_idx)
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
max_crops: int | None
overlap_margins: list[int] | None
crop_mode: str | None
patch_size: int | None
pooling_size: list[int] | None
class MolmoAct2ImageProcessor(BaseImageProcessor):
r"""
Constructs a MolmoAct2 image processor that preprocesses images for the model.
Args:
size (`dict[str, int]`, *optional*, defaults to `{"height": 378, "width": 378}`):
Size of the image after resizing.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Resampling filter to use when resizing the image.
image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
max_crops (`int`, *optional*, defaults to `8`):
Maximum number of crops to use per image.
overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
Overlap margins to use.
crop_mode (`str`, *optional*, defaults to `"overlap-and-resize-c2"`):
Cropping strategy. One of `"resize"`, `"overlap-and-resize"` or `"overlap-and-resize-c2"`.
patch_size (`int`, *optional*, defaults to 14):
The spatial patch size of the vision encoder.
pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
The pooling size of the vision adapter.
"""
model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]
def __init__(
self,
size: dict[str, int] | None = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool = True,
max_crops: int = 8,
overlap_margins: list[int] = [4, 4],
crop_mode: str = "overlap-and-resize-c2",
patch_size: int = 14,
pooling_size: list[int] = [2, 2],
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 378, "width": 378}
size = get_size_dict(size, default_to_square=True)
self.size = size
self.resample = resample
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_convert_rgb = do_convert_rgb
self.max_crops = max_crops
self.overlap_margins = overlap_margins
self.crop_mode = crop_mode
self.patch_size = patch_size
self.pooling_size = pooling_size
def preprocess(
self,
images: ImageInput,
size: dict[str, int] | None = None,
resample: PILImageResampling | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool | None = None,
max_crops: int | None = None,
overlap_margins: list[int] | None = None,
crop_mode: str | None = None,
patch_size: int | None = None,
pooling_size: list[int] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
"""
Args:
images (`ImageInput`):
Image to preprocess.
size (`dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`.
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization.
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
max_crops (`int`, *optional*, defaults to `self.max_crops`):
Maximum number of crops to use per image.
overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
Overlap margins to use.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
The spatial patch size of the vision encoder.
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
The pooling size of the vision adapter.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
Returns:
A `BatchFeature` containing the following keys:
- `pixel_values`: The preprocessed images.
- `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
- `image_grids`: The image grids.
- `image_num_crops`: The number of crops for each image.
"""
if size is not None:
if "height" not in size or "width" not in size:
raise ValueError("size must contain 'height' and 'width' keys.")
else:
size = {**self.size}
base_image_input_size = [size["height"], size["width"]]
resample = resample or self.resample
image_mean = image_mean or self.image_mean
image_std = image_std or self.image_std
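# Use an explicit None check so that passing `do_convert_rgb=False` is respected.
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb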
max_crops = max_crops or self.max_crops
overlap_margins = overlap_margins or self.overlap_margins
crop_mode = crop_mode or self.crop_mode
patch_size = patch_size or self.patch_size
pooling_size = pooling_size or self.pooling_size
image_pooling_h, image_pooling_w = pooling_size
if images is not None:
images = self.fetch_images(images)
images = make_flat_list_of_images(images)
if images is not None and not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
data = {}
if images is not None:
batch_grids = []
batch_crops = []
batch_pooled_patches_idx = []
batch_num_crops = []
for image in images:
image_grid, crops, pooled_idx = image_to_patches_and_grids(
image,
max_crops,
overlap_margins,
base_image_input_size,
resample,
image_mean,
image_std,
patch_size,
image_pooling_w,
image_pooling_h,
crop_mode,
)
batch_grids.append(image_grid)
batch_crops.append(crops)
batch_pooled_patches_idx.append(pooled_idx)
batch_num_crops.append(crops.shape[0])
pixel_values = np.concatenate(batch_crops, 0)
image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
image_grids = np.concatenate(batch_grids, 0)
image_num_crops = np.array(batch_num_crops)
data.update(
pixel_values=pixel_values,
image_token_pooling=image_token_pooling,
image_grids=image_grids,
image_num_crops=image_num_crops,
)
return BatchFeature(data, tensor_type=return_tensors)
MolmoAct2ImageProcessor.register_for_auto_class()

@@ -1,748 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Inference utilities for MolmoAct2"""
from dataclasses import dataclass
from typing import Any, Optional, Tuple
from collections.abc import Iterable, Sequence
import torch
from torch.nn import functional as F
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
@dataclass
class _ActionFlowInputs:
trajectory: torch.Tensor
context: Any
modulations: Sequence[Any]
action_dim_is_pad: torch.Tensor | None
@dataclass
class _ActionFlowCudaGraph:
key: tuple[Any, ...]
graph: torch.cuda.CUDAGraph
static_inputs: _ActionFlowInputs
output: torch.Tensor
@dataclass
class _DepthDecodeCudaGraphLayerStage:
residual: torch.Tensor
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
@dataclass
class _DepthDecodeCudaGraphPostStage:
graph: torch.cuda.CUDAGraph
attn_context: torch.Tensor
@dataclass
class _DepthDecodeCudaGraph:
cache_key: tuple[Any, ...]
pre_graph: torch.cuda.CUDAGraph
token_ids: torch.Tensor
cos: torch.Tensor
sin: torch.Tensor
positions: torch.Tensor
stages: Sequence[_DepthDecodeCudaGraphLayerStage]
post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
output: torch.Tensor
@dataclass
class _DepthDecodeCudaGraphSpec:
eligible: bool
cache_key_prefix: tuple[Any, ...]
num_hidden_layers: int
head_dim: int
num_attention_heads: int
def _cache_seq_len_int(past_key_values: Cache | None) -> int:
if past_key_values is None:
return 0
seq_len = past_key_values.get_seq_length()
if torch.is_tensor(seq_len):
return int(seq_len.item())
return int(seq_len)
def _cache_max_len_int(past_key_values: Cache | None) -> int:
if past_key_values is None:
return -1
max_len = past_key_values.get_max_cache_shape()
if torch.is_tensor(max_len):
return int(max_len.item())
return int(max_len)
def _iter_cache_key_values(
past_key_values: Cache,
) -> Iterable[tuple[torch.Tensor | None, torch.Tensor | None]]:
layers = getattr(past_key_values, "layers", None)
if layers is not None:
for layer in layers:
yield getattr(layer, "keys", None), getattr(layer, "values", None)
return
for layer in past_key_values:
yield layer[0], layer[1]
class _DepthDecodeStaticLayerCache:
is_compileable = False
is_sliding = False
def __init__(self, max_cache_len: int) -> None:
self.max_cache_len = int(max_cache_len)
self.cumulative_length = 0
self.keys: torch.Tensor | None = None
self.values: torch.Tensor | None = None
def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
bsz, n_heads = key_states.shape[:2]
self.keys = torch.empty(
(bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
dtype=key_states.dtype,
device=key_states.device,
)
self.values = torch.empty(
(bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
dtype=value_states.dtype,
device=value_states.device,
)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
*args,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.keys is None:
self._allocate(key_states, value_states)
start = self.cumulative_length
end = start + key_states.shape[-2]
if end > self.max_cache_len:
raise RuntimeError(f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}.")
self.keys[:, :, start:end, :].copy_(key_states)
self.values[:, :, start:end, :].copy_(value_states)
self.cumulative_length = end
return self.keys[:, :, :end, :], self.values[:, :, :end, :]
def get_seq_length(self) -> int:
return self.cumulative_length
def get_max_cache_shape(self) -> int:
return -1
def reset(self) -> None:
self.cumulative_length = 0
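# Illustrative sketch (not part of the original file): the preallocated-buffer idea
# behind `_DepthDecodeStaticLayerCache`, reduced to keys only and written with NumPy so
# it runs without a GPU. The class name `StaticLayerCache` is hypothetical.

```python
import numpy as np


class StaticLayerCache:
    """Fixed-size key buffer that is filled in place as tokens arrive."""

    def __init__(self, max_len: int) -> None:
        self.max_len = max_len
        self.length = 0
        self.keys = None

    def update(self, key_states: np.ndarray) -> np.ndarray:
        # key_states: [batch, heads, new_tokens, head_dim]
        if self.keys is None:
            b, h, _, d = key_states.shape
            # Allocate once; later updates never reallocate.
            self.keys = np.empty((b, h, self.max_len, d), dtype=key_states.dtype)
        end = self.length + key_states.shape[-2]
        if end > self.max_len:
            raise RuntimeError(f"KV cache length {end} exceeds max_len={self.max_len}.")
        self.keys[:, :, self.length : end, :] = key_states
        self.length = end
        # Return a view of the filled prefix, not a copy.
        return self.keys[:, :, :end, :]


cache = StaticLayerCache(max_len=4)
first = cache.update(np.ones((1, 2, 3, 8), dtype=np.float32))
second = cache.update(np.zeros((1, 2, 1, 8), dtype=np.float32))
print(first.shape, second.shape)  # (1, 2, 3, 8) (1, 2, 4, 8)
```

# Because the buffer address never changes, such a cache is a natural fit for code that
# is replayed from a captured CUDA graph, which is presumably why the original uses it.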
class _DepthDecodeStaticCache(Cache):
def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
text_config = config.get_text_config(decoder=True)
super().__init__(
layers=[
_DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
for _ in range(text_config.num_hidden_layers)
]
)
def get_seq_length(self, layer_idx: int = 0) -> int:
return self.layers[layer_idx].get_seq_length()
def get_max_cache_shape(self, layer_idx: int = 0) -> int:
return self.layers[layer_idx].get_max_cache_shape()
def reset(self) -> None:
for layer in self.layers:
layer.reset()
class ActionCudaGraphManager:
def __init__(self, model: Any) -> None:
self.model = model
self.enabled = True
self.action_flow_graph: _ActionFlowCudaGraph | None = None
def set_enabled(self, enabled: bool) -> None:
self.enabled = bool(enabled)
def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
action_model = self.model
if not self.enabled:
return False
if action_model.training or action_model._require_action_expert().training:
return False
if inputs.trajectory.device.type != "cuda":
return False
def all_on_cuda():
yield inputs.trajectory
for k, v in inputs.context.kv_contexts:
yield k
yield v
for t in (
inputs.context.cross_mask,
inputs.context.self_mask,
inputs.context.valid_action,
inputs.action_dim_is_pad,
):
if t is not None:
yield t
if inputs.context.rope_cache is not None:
yield from inputs.context.rope_cache
for step in inputs.modulations:
yield step.conditioning
for block_modulation in step.block_modulations:
yield from block_modulation
yield from step.final_modulation
return all(t.device.type == "cuda" for t in all_on_cuda())
def run_action_flow(
self,
inputs: _ActionFlowInputs,
steps: int,
run_loop,
) -> torch.Tensor:
key = _cuda_graph_key(inputs, steps)
cache = self.action_flow_graph
if cache is None or cache.key != key:
static_inputs = _clone_static_inputs(inputs)
graph, output = _capture_cuda_graph(
lambda: run_loop(static_inputs, steps),
inputs.trajectory.device,
after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
)
cache = _ActionFlowCudaGraph(
key=key,
graph=graph,
static_inputs=static_inputs,
output=output,
)
self.action_flow_graph = cache
else:
_copy_inputs_(cache.static_inputs, inputs)
cache.graph.replay()
return cache.output.clone()
class DepthDecodeCudaGraphManager:
def __init__(self, model: Any) -> None:
self.model = model
self.backbone = model.model
self.enabled = True
self.graph: _DepthDecodeCudaGraph | None = None
self.graph_spec: _DepthDecodeCudaGraphSpec | None = None
def set_enabled(self, enabled: bool) -> None:
self.enabled = bool(enabled)
def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
return _DepthDecodeStaticCache(
config=self.model.config.text_config,
max_cache_len=max_cache_len,
)
def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
static = self.graph_spec
if static is None:
cfg = self.backbone.transformer.config
rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
static = _DepthDecodeCudaGraphSpec(
eligible=(
not cfg.norm_after
and cfg.rope_scaling_layers is None
and getattr(rotary_emb, "rope_type", None) == "default"
and cfg._attn_implementation == "sdpa"
),
cache_key_prefix=(
cfg.hidden_size,
cfg.num_attention_heads,
cfg.num_key_value_heads,
cfg.head_dim,
cfg.num_hidden_layers,
cfg.use_qk_norm,
cfg.qk_norm_type,
cfg._attn_implementation,
),
num_hidden_layers=cfg.num_hidden_layers,
head_dim=cfg.head_dim,
num_attention_heads=cfg.num_attention_heads,
)
self.graph_spec = static
return static
def can_use(
self,
next_input_ids: torch.Tensor,
*,
past_key_values: Cache,
attention_bias: torch.Tensor,
) -> bool:
if not self.enabled or self.model.training or self.backbone.transformer.training:
return False
if next_input_ids.device.type != "cuda":
return False
if next_input_ids.ndim != 2 or next_input_ids.shape[0] != 1 or next_input_ids.shape[1] != 1:
return False
if not isinstance(past_key_values, _DepthDecodeStaticCache):
return False
if not torch.is_tensor(attention_bias) or attention_bias.device != next_input_ids.device:
return False
return self._depth_decode_spec().eligible
def _depth_decode_key(
self,
next_input_ids: torch.Tensor,
attention_bias: torch.Tensor,
) -> tuple[Any, ...]:
device = next_input_ids.device
return (
self._depth_decode_spec().cache_key_prefix,
device.type,
device.index,
self.model.lm_head.weight.dtype,
attention_bias.shape[-1],
)
def _select_depth_decode_rope(self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int) -> None:
emb = self.backbone.transformer.rotary_emb
cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
def _depth_decode_pre_layer(
self,
layer_idx: int,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
block = self.backbone.transformer.blocks[layer_idx]
attention = block.self_attn
residual = hidden_states
hidden_states = block.attn_norm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, attention.head_dim)
qkv = attention.att_proj(hidden_states)
query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
value_states = value_states.view(hidden_shape)
apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
if apply_qk_norm and not norm_after_view:
query_states = attention.q_norm(query_states)
key_states = attention.k_norm(key_states)
query_states = query_states.view(hidden_shape)
key_states = key_states.view(hidden_shape)
if norm_after_view:
query_states = attention.q_norm(query_states)
key_states = attention.k_norm(key_states)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)
return residual, query_states, key_states, value_states
def _depth_decode_pre0(
self,
token_ids: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
inputs_embeds = self.model._embed_base_tokens(token_ids)
return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
def _depth_decode_post_layer(
self,
layer_idx: int,
residual: torch.Tensor,
attn_context: torch.Tensor,
) -> torch.Tensor:
block = self.backbone.transformer.blocks[layer_idx]
attention = block.self_attn
input_shape = residual.shape[:-1]
attn_output = attn_context.reshape(*input_shape, -1).contiguous()
attn_output = attention.attn_out(attn_output)
hidden_states = residual + block.dropout(attn_output)
residual = hidden_states
hidden_states = block.ff_norm(hidden_states)
hidden_states = block.mlp(hidden_states)
hidden_states = residual + block.dropout(hidden_states)
return hidden_states
def _depth_decode_post_and_pre_next(
self,
layer_idx: int,
residual: torch.Tensor,
attn_context: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
def _depth_decode_last_post(
self,
layer_idx: int,
residual: torch.Tensor,
attn_context: torch.Tensor,
) -> torch.Tensor:
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
return self.backbone.transformer.ln_f(hidden_states)
def _build_depth_decode_graph(
self,
next_input_ids: torch.Tensor,
*,
past_length: int,
attention_bias: torch.Tensor,
) -> _DepthDecodeCudaGraph:
text_config = self.backbone.transformer.config
device = next_input_ids.device
dtype = self.model.lm_head.weight.dtype
static = self._depth_decode_spec()
num_layers = static.num_hidden_layers
head_dim = static.head_dim
max_cache_len = int(attention_bias.shape[-1])
max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
self.backbone.transformer.prepare_rope_cache(device=device, max_seq_len=max_rope_len)
token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
sin = torch.empty_like(cos)
positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
context_shape = (1, 1, static.num_attention_heads, head_dim)
token_ids.copy_(next_input_ids)
self._select_depth_decode_rope(cos, sin, past_length=past_length)
pre_graph, pre_output = _capture_cuda_graph(
lambda: self._depth_decode_pre0(token_ids, cos, sin),
device,
)
stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
post_graphs = []
for layer_idx in range(num_layers - 1):
stage = stages[-1]
attn_context = torch.empty(context_shape, device=device, dtype=dtype)
graph, output = _capture_cuda_graph(
lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
self._depth_decode_post_and_pre_next(
layer_idx,
stage.residual,
attn_context,
cos,
sin,
)
),
device,
)
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context))
stages.append(_DepthDecodeCudaGraphLayerStage(*output))
last_stage = stages[-1]
last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
last_graph, last_output = _capture_cuda_graph(
lambda: self._depth_decode_last_post(
num_layers - 1,
last_stage.residual,
last_attn_context,
),
device,
)
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=last_graph, attn_context=last_attn_context))
return _DepthDecodeCudaGraph(
cache_key=self._depth_decode_key(next_input_ids, attention_bias),
pre_graph=pre_graph,
token_ids=token_ids,
cos=cos,
sin=sin,
positions=positions,
stages=tuple(stages),
post_graphs=tuple(post_graphs),
output=last_output,
)
def _get_depth_decode_graph(
self,
next_input_ids: torch.Tensor,
*,
past_length: int,
attention_bias: torch.Tensor,
) -> _DepthDecodeCudaGraph:
key = self._depth_decode_key(next_input_ids, attention_bias)
decode_graph = self.graph
if decode_graph is None or decode_graph.cache_key != key:
decode_graph = self._build_depth_decode_graph(
next_input_ids,
past_length=past_length,
attention_bias=attention_bias,
)
self.graph = decode_graph
else:
decode_graph.token_ids.copy_(next_input_ids)
self._select_depth_decode_rope(decode_graph.cos, decode_graph.sin, past_length=past_length)
return decode_graph
def _run_depth_decode_attention_core(
self,
layer_idx: int,
stage: _DepthDecodeCudaGraphLayerStage,
*,
past_key_values: Cache,
attention_bias: torch.Tensor,
cache_position: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
attention = self.backbone.transformer.blocks[layer_idx].self_attn
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(
stage.key,
stage.value,
layer_idx,
cache_kwargs,
)
key_states = _repeat_kv(key_states, attention.num_key_value_groups)
value_states = _repeat_kv(value_states, attention.num_key_value_groups)
attn_output = F.scaled_dot_product_attention(
stage.query,
key_states,
value_states,
attn_mask=attention_bias,
dropout_p=0.0,
is_causal=False,
)
return attn_output.transpose(1, 2)
def run(
self,
next_input_ids: torch.Tensor,
*,
past_key_values: Cache,
attention_bias: torch.Tensor,
past_length: int,
) -> tuple[torch.Tensor, Cache]:
end = past_length + 1
decode_graph = self._get_depth_decode_graph(
next_input_ids,
past_length=past_length,
attention_bias=attention_bias,
)
cache_position = decode_graph.positions[past_length:end]
attention_bias_q = attention_bias[:, :, past_length:end, :end]
decode_graph.pre_graph.replay()
for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
attn_context = self._run_depth_decode_attention_core(
layer_idx,
decode_graph.stages[layer_idx],
past_key_values=past_key_values,
attention_bias=attention_bias_q,
cache_position=cache_position,
cos=decode_graph.cos,
sin=decode_graph.sin,
)
post_graph.attn_context.copy_(attn_context)
post_graph.graph.replay()
return decode_graph.output, past_key_values
def _cuda_graph_tensor_signature(
tensor: torch.Tensor | None,
) -> tuple[Any, ...] | None:
if tensor is None:
return None
return (
tuple(tensor.shape),
tuple(tensor.stride()),
str(tensor.dtype),
str(tensor.device),
)
def _cuda_graph_context_signature(context: Any) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return (
tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
sig(context.cross_mask),
sig(context.self_mask),
sig(context.valid_action),
None if context.rope_cache is None else tuple(sig(t) for t in context.rope_cache),
)
def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return tuple(
(
sig(step.conditioning),
tuple(tuple(sig(t) for t in block_modulation) for block_modulation in step.block_modulations),
tuple(sig(t) for t in step.final_modulation),
)
for step in modulations
)
def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return (
sig(inputs.trajectory),
_cuda_graph_context_signature(inputs.context),
_cuda_graph_modulation_signature(inputs.modulations),
sig(inputs.action_dim_is_pad),
int(steps),
)
def _clone_static_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None:
if tensor is None:
return None
static = torch.empty_strided(
tuple(tensor.shape),
tuple(tensor.stride()),
device=tensor.device,
dtype=tensor.dtype,
)
static.copy_(tensor)
return static
def _clone_static_context(context: Any) -> Any:
rope_cache = None
if context.rope_cache is not None:
rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
return context.__class__(
kv_contexts=tuple((_clone_static_tensor(k), _clone_static_tensor(v)) for k, v in context.kv_contexts),
cross_mask=_clone_static_tensor(context.cross_mask),
self_mask=_clone_static_tensor(context.self_mask),
valid_action=_clone_static_tensor(context.valid_action),
rope_cache=rope_cache,
)
def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
return tuple(
step.__class__(
conditioning=_clone_static_tensor(step.conditioning),
block_modulations=tuple(
tuple(_clone_static_tensor(t) for t in block_modulation)
for block_modulation in step.block_modulations
),
final_modulation=tuple(_clone_static_tensor(t) for t in step.final_modulation),
)
for step in modulations
)
def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
return _ActionFlowInputs(
trajectory=_clone_static_tensor(inputs.trajectory),
context=_clone_static_context(inputs.context),
modulations=_clone_static_modulations(inputs.modulations),
action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
)
def _copy_context_(dst: Any, src: Any) -> None:
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
dst_k.copy_(src_k)
dst_v.copy_(src_v)
if src.cross_mask is not None:
dst.cross_mask.copy_(src.cross_mask)
if src.self_mask is not None:
dst.self_mask.copy_(src.self_mask)
if src.valid_action is not None:
dst.valid_action.copy_(src.valid_action)
if src.rope_cache is not None:
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
dst_tensor.copy_(src_tensor)
def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
dst.trajectory.copy_(src.trajectory)
_copy_context_(dst.context, src.context)
if src.action_dim_is_pad is not None:
dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
unsqueeze_dim: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (_rotate_half(q) * sin)
k_embed = (k * cos) + (_rotate_half(k) * sin)
return q_embed, k_embed
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def _capture_cuda_graph(
fn,
device: torch.device,
*,
after_warmup=None,
) -> tuple[torch.cuda.CUDAGraph, Any]:
warmup_stream = torch.cuda.Stream(device=device)
warmup_stream.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(warmup_stream):
fn()
torch.cuda.current_stream(device).wait_stream(warmup_stream)
if after_warmup is not None:
after_warmup()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
output = fn()
return graph, output
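
Review note on `_rotate_half` / `_apply_rotary_pos_emb` above: these helpers pair element `i` with element `i + d/2` and rotate each pair by one angle. The sketch below is a minimal pure-Python illustration of that convention, not the model code. It assumes the `cos`/`sin` tables repeat the same `d/2` angles across both halves, which is the usual layout for this style of rotary embedding (the diff does not show how `rope_cache` is built). The names `rotate_half` and `apply_rope` are hypothetical.

```python
import math


def rotate_half(x):
    """List-based analogue of `_rotate_half`: (x1, x2) -> (-x2, x1)."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + list(x1)


def apply_rope(x, angles):
    """Rotate each (i, i + half) pair of `x` by `angles[i]`.

    Assumption: cos/sin are duplicated across both halves so that
    x * cos + rotate_half(x) * sin is a plain 2-D rotation per pair.
    """
    cos = [math.cos(a) for a in angles] * 2
    sin = [math.sin(a) for a in angles] * 2
    rotated = rotate_half(x)
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotated, cos, sin)]


vec = [1.0, 2.0, 3.0, 4.0]
# Pair (1.0, 3.0) is left alone (angle 0); pair (2.0, 4.0) is turned by 90 degrees.
out = apply_rope(vec, [0.0, math.pi / 2])
```

Because each pair undergoes a pure rotation, the vector norm is unchanged, which is a quick sanity check when wiring up a new cache layout.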

File diff suppressed because it is too large


@@ -1,431 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""
Processor class for MolmoAct2.
"""
from typing import Optional, Union
import dataclasses
import numpy as np
from transformers.image_utils import ImageInput
from transformers.video_utils import VideoInput
from transformers.processing_utils import (
Unpack,
ProcessingKwargs,
ProcessorMixin,
)
from transformers.feature_extraction_utils import BatchFeature
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
from transformers.utils import logging
from transformers import AutoTokenizer
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
logger = logging.get_logger(__name__)
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
IMAGE_PATCH_TOKEN = "<im_patch>" # Where to insert high-res tokens
IMAGE_LOW_RES_TOKEN = "<im_low>" # Where to insert low-res tokens
IM_START_TOKEN = "<im_start>"
LOW_RES_IMAGE_START_TOKEN = "<low_res_im_start>"
FRAME_START_TOKEN = "<frame_start>"
IM_END_TOKEN = "<im_end>"
FRAME_END_TOKEN = "<frame_end>"
IM_COL_TOKEN = "<im_col>"
IMAGE_PROMPT = "<|image|>"
VIDEO_PROMPT = "<|video|>"
IMAGE_TOKENS = [
IMAGE_PATCH_TOKEN,
IM_COL_TOKEN,
IM_START_TOKEN,
LOW_RES_IMAGE_START_TOKEN,
FRAME_START_TOKEN,
IM_END_TOKEN,
FRAME_END_TOKEN,
IMAGE_LOW_RES_TOKEN,
]
class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
"""MolmoAct2 processor kwargs"""
images_kwargs: MolmoAct2ImagesKwargs
videos_kwargs: MolmoAct2VideoProcessorKwargs
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": True,
},
"videos_kwargs": {"return_metadata": True},
}
class MolmoAct2Processor(ProcessorMixin):
attributes = ["image_processor", "video_processor", "tokenizer"]
optional_attributes = [
"chat_template",
"time_mode",
"image_use_col_tokens",
"use_single_crop_col_tokens",
"use_single_crop_start_token",
"video_use_col_tokens",
"use_frame_special_tokens",
]
image_processor_class = "AutoImageProcessor"
video_processor_class = "AutoVideoProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor: MolmoAct2ImageProcessor = None,
video_processor: MolmoAct2VideoProcessor = None,
tokenizer: AutoTokenizer = None,
chat_template: str | None = None,
image_use_col_tokens: bool | None = True,
use_single_crop_col_tokens: bool | None = None,
use_single_crop_start_token: bool | None = True,
video_use_col_tokens: bool | None = False,
use_frame_special_tokens: bool | None = True,
**kwargs,
) -> None:
super().__init__(
image_processor,
video_processor,
tokenizer,
chat_template=chat_template,
)
self.image_use_col_tokens = image_use_col_tokens
self.use_single_crop_col_tokens = use_single_crop_col_tokens
self.use_single_crop_start_token = use_single_crop_start_token
self.video_use_col_tokens = video_use_col_tokens
self.use_frame_special_tokens = use_frame_special_tokens
self.image_placeholder_token = IMAGE_PROMPT
self.video_placeholder_token = VIDEO_PROMPT
self.image_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in IMAGE_TOKENS]
def get_image_tokens(self, image_grid: np.ndarray):
resized_h, resized_w, height, width = image_grid
if int(height) == 0 or int(width) == 0:
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
use_single_crop_col_tokens = (
self.image_use_col_tokens
if self.use_single_crop_col_tokens is None
else self.use_single_crop_col_tokens
)
if use_single_crop_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
joint = [
[IM_START_TOKEN],
np.tile(per_row, [resized_h]),
[IM_END_TOKEN],
]
return np.concatenate(joint)
per_row = np.full(width, IMAGE_PATCH_TOKEN)
if self.image_use_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
joint = [
[IM_START_TOKEN],
np.tile(per_row, [height]),
[IM_END_TOKEN],
]
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
use_single_crop_col_tokens = (
self.image_use_col_tokens
if self.use_single_crop_col_tokens is None
else self.use_single_crop_col_tokens
)
image_start_token = LOW_RES_IMAGE_START_TOKEN if self.use_single_crop_start_token else IM_START_TOKEN
if use_single_crop_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
joint = [
[image_start_token],
np.tile(per_row, [resized_h]),
[IM_END_TOKEN],
] + joint
return np.concatenate(joint)
def get_video_string(
self,
video_grid: np.ndarray,
timestamps: np.ndarray,
):
if self.use_frame_special_tokens:
start_token_id = FRAME_START_TOKEN
end_token_id = FRAME_END_TOKEN
else:
start_token_id = IM_START_TOKEN
end_token_id = IM_END_TOKEN
num_frames, h, w = video_grid
video_string: str = ""
for frame_idx, frame_time in enumerate(timestamps):
# `per-frame-compact` time mode
prev_space = " " if frame_idx > 0 else ""
frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
video_string += frame_prefix
per_row = np.full(w, IMAGE_PATCH_TOKEN)
if self.video_use_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
extra_tokens = np.tile(per_row, [h])
video_tokens = [
[start_token_id],
extra_tokens,
[end_token_id],
]
video_string += "".join(np.concatenate(video_tokens, 0))
return video_string
def insert_bos(
self,
input_ids: np.ndarray,
attention_mask: np.ndarray,
bos_token_id: int,
pad_token_id: int,
):
"""
Args:
input_ids: [B, S] array with left padding
attention_mask: [B, S] array (0 for pad, 1 for valid)
bos_token_id: int
pad_token_id: int
Returns:
input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
attention_mask_out: same shape as input_ids_out
"""
need_to_expand = len(input_ids.shape) == 1
if need_to_expand:
input_ids = input_ids[None, :]
attention_mask = attention_mask[None, :]
B, S = input_ids.shape
# Handle zero-length sequence
if S == 0:
new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
if need_to_expand:
new_input_ids = new_input_ids[0]
new_attention_mask = new_attention_mask[0]
return new_input_ids, new_attention_mask
first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
if bos_already_present:
if need_to_expand:
input_ids = input_ids[0]
attention_mask = attention_mask[0]
return input_ids, attention_mask
else:
new_input_ids = np.full((B, S + 1), pad_token_id, dtype=input_ids.dtype)
new_attention_mask = np.zeros((B, S + 1), dtype=attention_mask.dtype)
src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
tgt_idx = src_idx + 1 # shift right
batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
# flatten valid_positions
flat_vals = input_ids[valid_mask]
flat_batch = batch_idx[valid_mask]
flat_tgt = tgt_idx[valid_mask]
new_input_ids[flat_batch, flat_tgt] = flat_vals
new_attention_mask[flat_batch, flat_tgt] = 1
insert_pos = first_valid_index
new_input_ids[np.arange(B), insert_pos] = bos_token_id
new_attention_mask[np.arange(B), insert_pos] = 1
if need_to_expand:
new_input_ids = new_input_ids[0]
new_attention_mask = new_attention_mask[0]
return new_input_ids, new_attention_mask
def __call__(
self,
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
images: ImageInput = None,
videos: VideoInput = None,
**kwargs: Unpack[MolmoAct2ProcessorKwargs],
) -> BatchFeature:
"""
Args:
text (`str`, `list[str]`, `list[list[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
videos (`dict[str, Any]` or `list[dict[str, Any]]`):
The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
- `"frames"`: `np.ndarray` of shape (T, H, W, 3)
- `"timestamps"`: `np.ndarray` of shape (T,)
- `"sampled_fps"`: `float` (optional)
- `"sampling_augmentation"`: `str` (optional)
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
`BatchFeature`: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
Returned when `images` is not `None`.
- **image_grids** -- Grids of images. Returned when `images` is not `None`.
- **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
- **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
Returned when `videos` is not `None`.
- **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
"""
output_kwargs = self._merge_kwargs(
MolmoAct2ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
image_grids = image_inputs["image_grids"]
else:
image_inputs = {}
image_grids = None
if videos is not None:
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grids = videos_inputs["video_grids"]
# If user has not requested video metadata, pop it
if "return_metadata" not in kwargs:
video_metadata = videos_inputs.pop("video_metadata")
else:
video_metadata = videos_inputs["video_metadata"]
else:
videos_inputs = {}
video_grids = None
if not isinstance(text, list):
text = [text]
text = text.copy() # below lines change text in-place
if image_grids is not None:
index = 0
for i in range(len(text)):
num_images = text[i].count(self.image_placeholder_token)
image_grids_i = image_grids[index : index + num_images]
for image_grid in image_grids_i:
image_tokens = self.get_image_tokens(image_grid)
image_string = "".join(image_tokens)
text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
index += num_images
if video_grids is not None:
index = 0
for i in range(len(text)):
num_videos = text[i].count(self.video_placeholder_token)
assert num_videos in {0, 1}, "At most one video is supported for now"
video_grids_i = video_grids[index : index + num_videos]
metadata_i = video_metadata[index : index + num_videos]
for video_grid, metadata in zip(video_grids_i, metadata_i):
video_string = self.get_video_string(
video_grid,
metadata.timestamps,
)
text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
index += num_videos
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
input_ids = text_inputs["input_ids"]
attention_mask = text_inputs["attention_mask"]
input_ids = np.array(input_ids)
attention_mask = np.array(attention_mask)
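# Fall back to EOS only when no BOS id is defined; a plain `or` would also skip a valid BOS id of 0
bos = self.tokenizer.bos_token_id if self.tokenizer.bos_token_id is not None else self.tokenizer.eos_token_id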
input_ids, attention_mask = self.insert_bos(
input_ids, attention_mask, bos, self.tokenizer.pad_token_id
)
if return_mm_token_type_ids:
image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
text_inputs["token_type_ids"] = token_type_ids.tolist()
text_inputs["input_ids"] = input_ids.tolist()
text_inputs["attention_mask"] = attention_mask.tolist()
return BatchFeature(
data={**text_inputs, **image_inputs, **videos_inputs},
tensor_type=return_tensors,
)
def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode` method.
Returns:
`list[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
MolmoAct2Processor.register_for_auto_class()
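
Review note on `insert_bos` above: the vectorized NumPy version is easier to follow next to a single-sequence equivalent. The sketch below is a list-based restatement for one left-padded sequence, written for illustration only; `insert_bos_single` is a hypothetical name and is not part of the processor. One difference to keep in mind: the batched original skips insertion only when every row already starts with BOS, whereas this sketch looks at a single row.

```python
def insert_bos_single(input_ids, attention_mask, bos_token_id, pad_token_id):
    """Insert BOS before the first valid token of one left-padded sequence."""
    # Zero-length input: the result is just the BOS token.
    if not input_ids:
        return [bos_token_id], [1]

    # Index of the first attended token (0 if the mask is all zeros,
    # matching what argmax returns in the NumPy version).
    first_valid = attention_mask.index(1) if 1 in attention_mask else 0

    # Nothing to do when BOS is already in place.
    if input_ids[first_valid] == bos_token_id:
        return list(input_ids), list(attention_mask)

    # Keep the left padding, put BOS where the first valid token was,
    # and shift the remaining tokens right by one position.
    tail = list(input_ids[first_valid:])
    new_ids = [pad_token_id] * first_valid + [bos_token_id] + tail
    new_mask = [0] * first_valid + [1] * (len(tail) + 1)
    return new_ids, new_mask


ids, mask = insert_bos_single([0, 0, 5, 6], [0, 0, 1, 1], bos_token_id=1, pad_token_id=0)
```

The output is one element longer than the input whenever BOS had to be inserted, which is why the docstring lists both `[B, S]` and `[B, S+1]` as possible shapes.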


@@ -1,997 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Video processor class for MolmoAct2"""
from functools import partial
import os
import warnings
from contextlib import redirect_stdout
from io import BytesIO
from urllib.parse import urlparse
from typing import Optional, Union
from collections.abc import Callable
import numpy as np
import requests
import einops
import torch
import torchvision.transforms
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
SizeDict,
validate_kwargs,
)
from transformers.video_utils import (
VideoInput,
is_valid_video,
make_batched_videos,
make_batched_metadata,
VideoMetadata,
)
from transformers.processing_utils import Unpack, VideosKwargs
from transformers.video_processing_utils import BaseVideoProcessor
from transformers.utils import logging
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import (
is_av_available,
is_decord_available,
is_torchcodec_available,
is_yt_dlp_available,
TensorType,
logging,
to_numpy,
)
logger = logging.get_logger(__name__)
MAX_VIDEO_FPS = 8
def normalize_image(
image: np.ndarray,
image_mean: list[float],
image_std: list[float],
) -> np.ndarray:
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
image /= np.array(image_std, dtype=np.float32)[None, None, :]
return image
def resize_image(
image: np.ndarray,
desired_output_size: list[int],
resample: PILImageResampling,
) -> np.ndarray:
if len(image.shape) == 3:
is_video = False
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
else:
is_video = True
image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2])
dtype = image.dtype
if torch.is_floating_point(image):
in_min = 0.0
in_max = 1.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
)
in_min = 0.0
in_max = 255.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0, 255).to(dtype)
resized = resized.to(torch.float32)
resized = (resized - in_min) / (in_max - in_min)
if is_video:
resized = torch.permute(resized, [0, 2, 3, 1]).numpy()
else:
resized = torch.permute(resized, [1, 2, 0]).numpy()
return resized
def build_resized_image(
image: np.ndarray,
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
resized = resize_image(
image,
base_image_input_size,
resample,
)
resized = normalize_image(resized, image_mean, image_std)
if len(resized.shape) == 3:
resized = np.expand_dims(resized, 0)
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
return resized, resize_idx
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
if len(array.shape) == 3:
n_crops, h, w = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
array = np.transpose(array, [0, 1, 3, 2, 4])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
return array
else:
n_crops, h, w, c = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
return array
def arange_for_pooling(
idx_arr: np.ndarray,
pool_h: int,
pool_w: int,
) -> np.ndarray:
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
idx_arr = np.pad(
idx_arr,
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
mode="constant",
constant_values=-1,
)
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
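
Review note on `arange_for_pooling` above: the function pads the patch-index grid with `-1` until both sides are multiples of the pooling window, then groups indices window by window. The sketch below restates that in plain Python lists so the padding and ordering are visible without `numpy` or `einops`. It is an illustration under the same padding rule (extra padding goes to the bottom/right when the pad is odd); `pooling_windows` is a hypothetical name.

```python
def pooling_windows(rows, cols, pool_h, pool_w):
    """Group a rows x cols patch-index grid into pool_h x pool_w windows."""
    # Padding needed to reach the next multiple of the window size.
    h_pad = -rows % pool_h
    w_pad = -cols % pool_w
    top, left = h_pad // 2, w_pad // 2
    padded_h, padded_w = rows + h_pad, cols + w_pad

    # Start from an all -1 grid and write the real indices into place.
    padded = [[-1] * padded_w for _ in range(padded_h)]
    for r in range(rows):
        for c in range(cols):
            padded[r + top][c + left] = r * cols + c

    # Each output cell lists its window row-major: "(dh dw)" in einops terms.
    windows = []
    for wr in range(0, padded_h, pool_h):
        row = []
        for wc in range(0, padded_w, pool_w):
            row.append(
                [padded[wr + dr][wc + dc] for dr in range(pool_h) for dc in range(pool_w)]
            )
        windows.append(row)
    return windows


# A 3x3 patch grid pooled 2x2 becomes a 2x2 grid of 4-element windows.
grid = pooling_windows(3, 3, 2, 2)
```

Entries equal to `-1` mark padding slots, which the pooling step is expected to mask out.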
def image_to_patches_and_grids(
image: ImageInput,
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
image_pooling_w: int,
image_pooling_h: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
:return image_grids, the shape of each image after pooling
:return crops, the image crops to process with the ViT
:return pooled_patch_idx, for each patch token in `image_tokens`, the indices of the
patches in `crops` to pool for that token, padded with -1
"""
if isinstance(base_image_input_size, int):
base_image_input_size = (base_image_input_size, base_image_input_size)
pooling_w = image_pooling_w
pooling_h = image_pooling_h
resized, resize_idx = build_resized_image(
image,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
h, w = pooling_idx.shape[:2]
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
image_grid = [h, w]
return (
image_grid,
batch_pixels_to_patches(resized, image_patch_size),
pooling_idx,
)
def get_candidate_target_fps(
video_fps: int | float,
sampling_fps: int | float,
max_fps: int | float = MAX_VIDEO_FPS,
) -> list[float]:
"""
Return the divisors of `video_fps` that are multiples of `sampling_fps` and do not exceed `max_fps`.
Examples:
>>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
[2.0, 6.0]
>>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
[1.0, 5.0]
>>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
[2.0]
>>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
Traceback (most recent call last):
...
ValueError: sampling_fps=2 must divide video_fps=5.
"""
# Check for a missing value before converting; int(None) would raise a TypeError first
if sampling_fps is None:
raise ValueError("sampling_fps must be provided")
video_fps = int(video_fps)
sampling_fps = int(sampling_fps)
max_fps = int(max_fps)
if video_fps <= 0 or sampling_fps <= 0:
raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
if video_fps % sampling_fps != 0:
raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")
candidates = []
for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
if candidate > max_fps:
break
if video_fps % candidate == 0:
candidates.append(float(candidate))
return candidates
def read_video_decord(
video_path,
sample_timestamps_fn: Callable,
**kwargs,
) -> tuple[np.ndarray, VideoMetadata]:
"""
Decode a video using the Decord backend.
Args:
video_path (`str`):
Path to the video file.
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
Returns:
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import from decord
import importlib
decord = importlib.import_module("decord")
vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu
video_fps = vr.get_avg_fps()
total_num_frames = len(vr)
time_stamps = vr.get_frame_timestamp(list(range(len(vr))))
duration = time_stamps[-1][1] - time_stamps[0][0]
metadata = VideoMetadata(
total_num_frames=int(total_num_frames),
fps=float(video_fps),
duration=float(duration),
video_backend="decord",
)
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
target_timestamps = np.array(target_timestamps)
offset = time_stamps[0, 0]
ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side="right")
ix = np.minimum(ix, len(time_stamps) - 1)
video = vr.get_batch(ix).asnumpy()
metadata.update(
{
"frames_indices": target_timestamps * video_fps,
"height": video.shape[1],
"width": video.shape[2],
}
)
return video, metadata
def read_video_torchcodec(
video_path,
sample_timestamps_fn: Callable,
**kwargs,
) -> tuple[np.ndarray, VideoMetadata]:
"""
Decode a video using torchcodec decoder.
Args:
video_path (`str`):
Path to the video file.
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
Returns:
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import torchcodec
import importlib
torchcodec = importlib.import_module("torchcodec")
decoder = torchcodec.decoders.VideoDecoder(
video_path,
# Interestingly, `exact` mode takes less time than `approximate` when we load the whole video
seek_mode="exact",
# Allow FFmpeg to decide on the number of threads for efficiency
num_ffmpeg_threads=0,
)
# If the first frame starts at > 0, we effectively clip the video starting at that time
# since (most) video players would also skip to that time
time_offset = decoder.metadata.begin_stream_seconds_from_content
# Note this duration does assume we started playing at `time_offset`
duration = decoder.metadata.duration_seconds
metadata = VideoMetadata(
total_num_frames=decoder.metadata.num_frames,
fps=decoder.metadata.average_fps,
duration=duration,
video_backend="torchcodec",
height=decoder.metadata.height,
width=decoder.metadata.width,
)
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
# Floating point/rounding issues might cause `target_timestamps` to be very slightly
# out-of-bounds, to handle this we sanity check then clip them
assert all(x >= 0 for x in target_timestamps)
assert all(x < duration + 1e-6 for x in target_timestamps)
# 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
# exact boundary value, we should still get the first/last frame anyway
max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6
# Note we avoid using numpy ops here to reduce floating precision issues
timestamps = [x + time_offset for x in target_timestamps]
timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]
video = (
decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1)
) # Convert to THWC format
target_timestamps = np.array(target_timestamps)
metadata.frames_indices = target_timestamps * metadata.fps
return video, metadata
def read_video_pyav(
video_path,
sample_timestamps_fn: Callable,
**kwargs,
) -> tuple[np.ndarray, VideoMetadata]:
"""
Decode a video using the PyAV backend.
Args:
video_path (`str`):
Path to the video file.
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
Returns:
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import PyAV
import importlib
av = importlib.import_module("av")
with av.open(video_path) as container:
video_stream = container.streams.video[0]
fps = video_stream.average_rate or video_stream.guessed_rate
it = container.decode(video=0)
frames = list(it)
stream = container.streams.video[0]
start = frames[0].pts * stream.time_base
container_end = stream.duration
if container_end is not None:
container_end *= stream.time_base
if container_end is None or container_end < frames[-1].pts * stream.time_base:
# Some problem with stream duration, so use the frame PTS directly
# and guess the duration of the last frame
end = frames[-1].pts * stream.time_base + 1 / fps
else:
end = container_end
duration = float(end - start)
metadata = VideoMetadata(
total_num_frames=len(frames),
fps=float(fps),
duration=float(duration),
video_backend="pyav",
height=video_stream.height,
width=video_stream.width,
)
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
offset = float(start)
target_timestamps = np.array(target_timestamps)
end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side="right")
indices = np.minimum(indices, len(end_time_stamps) - 1)
video = np.stack(
[frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
axis=0,
)
metadata.frames_indices = target_timestamps * fps
return video, metadata
VIDEO_DECODERS = {
"decord": read_video_decord,
"torchcodec": read_video_torchcodec,
"pyav": read_video_pyav,
}
def load_video(
video: VideoInput,
backend: str = "decord",
sample_timestamps_fn: Callable | None = None,
**kwargs,
):
"""
Loads `video` to a numpy array.
Args:
video (`VideoInput`):
The video to convert to the numpy array format. Can be a link to video or local path.
backend (`str`, *optional*, defaults to `"decord"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "torchcodec"]. Defaults to "decord".
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
"""
# Early exit if provided an array or `PIL` frames
if not isinstance(video, str):
metadata = [None] * len(video)
return video, metadata
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
if not is_yt_dlp_available():
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
# Lazy import from yt_dlp
import importlib
yt_dlp = importlib.import_module("yt_dlp")
buffer = BytesIO()
with redirect_stdout(buffer), yt_dlp.YoutubeDL() as f:
f.download([video])
bytes_obj = buffer.getvalue()
file_obj = BytesIO(bytes_obj)
elif video.startswith("http://") or video.startswith("https://"):
file_obj = BytesIO(requests.get(video, timeout=10).content)
elif os.path.isfile(video):
file_obj = video
else:
raise TypeError(
"Incorrect format used for video. Should be a URL linking to a video or a local path."
)
# can also load with decord, but not cv2/torchvision
# both will fail in case of url links
video_is_url = video.startswith("http://") or video.startswith("https://")
if video_is_url and backend == "opencv":
raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
if (
(not is_decord_available() and backend == "decord")
or (not is_torchcodec_available() and backend == "torchcodec")
or (not is_av_available() and backend == "pyav")
):
raise ImportError(
f"You chose backend={backend} for loading the video but the required library is not found in your environment. "
f"Make sure to install {backend} before loading the video."
)
video_decoder = VIDEO_DECODERS[backend]
video, metadata = video_decoder(file_obj, sample_timestamps_fn, **kwargs)
return video, metadata
def get_target_fps(
video_fps: float,
max_frames: int,
total_frames: int,
frame_sample_mode: str,
candidate_target_fps: tuple[float],
) -> float:
"""
Get the target fps that best spans the video and has the most frames sampled
"""
num_frames_sampled = 0
selected_target_fps = None
for target_fps in candidate_target_fps:
step_size = max(int(video_fps / target_fps), 1)
num_frames_sampled_at_fps = int(total_frames / step_size)
if num_frames_sampled == 0:
if "uniform" in frame_sample_mode:
if num_frames_sampled_at_fps > max_frames:
break
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
else:
# the candidate sampling fps increases so frame count can't decrease
assert num_frames_sampled <= num_frames_sampled_at_fps
if num_frames_sampled_at_fps > max_frames:
# choose the sampling fps that spans the video
continue
elif num_frames_sampled_at_fps > num_frames_sampled:
# both are less than max_frames, choose the one with higher density of frames sampled
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
return selected_target_fps
def get_frame_times_and_chosen_fps(selected_target_fps, total_frames, max_frames, video_fps):
if selected_target_fps is None:
frame_indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
else:
step_size = max(int(video_fps / selected_target_fps), 1)
frame_indices = np.arange(0, total_frames, step_size)
if len(frame_indices) > max_frames:
frame_indices = frame_indices[:max_frames]
return selected_target_fps, frame_indices
class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
patch_size: int | None
pooling_size: list[int] | None
frame_sample_mode: str | None
max_fps: int | None
sampling_fps: int | None
class MolmoAct2VideoProcessor(BaseVideoProcessor):
resample = PILImageResampling.BILINEAR
size = {"height": 378, "width": 378}
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
patch_size = 14
pooling_size = [3, 3]
do_sample_frames = True
frame_sample_mode = "uniform_last_frame"
max_fps = 2
sampling_fps = 2
valid_kwargs = MolmoAct2VideoProcessorKwargs
model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]
def __init__(self, **kwargs: Unpack[MolmoAct2VideoProcessorKwargs]):
super().__init__(**kwargs)
if self.size is not None and (
self.size.get("height", None) is None or self.size.get("width", None) is None
):
raise ValueError("size must contain 'height' and 'width' keys.")
def _further_process_kwargs(
self,
size: SizeDict | None = None,
**kwargs,
) -> dict:
"""
Update kwargs that need further processing before being validated.
Can be overridden by subclasses to customize the processing of kwargs.
"""
if size is not None and ("height" not in size or "width" not in size):
raise ValueError("size must contain 'height' and 'width' keys.")
return super()._further_process_kwargs(size=size, **kwargs)
def sample_times(
self,
metadata: VideoMetadata,
frame_sample_mode: str,
num_frames: int,
max_fps: int | None = None,
sampling_fps: int | None = None,
**kwargs,
) -> np.ndarray:
"""
Time-based sampling, used when the video is decoded from a path or URL rather than passed as an array
Args:
metadata (`VideoMetadata`):
Metadata of the video containing information about total duration, fps and total number of frames.
frame_sample_mode (`str`, *optional*):
Mode to sample frames. Defaults to `self.frame_sample_mode`.
num_frames (`int`, *optional*):
Maximum number of frames to sample. Defaults to `self.num_frames`.
max_fps (`int`, *optional*):
Maximum frames per second to sample.
sampling_fps (`int`, *optional*):
Sampling frames per second. Defaults to `self.sampling_fps`.
Used when `frame_sample_mode` is `"fps"`.
"""
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
num_frames = num_frames or self.num_frames
sampling_fps = sampling_fps or self.sampling_fps
duration = metadata.duration or metadata.total_num_frames / metadata.fps
if frame_sample_mode == "fps":
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
# Try larger and larger FPSs until we hit one that can't span the video
target_fps = candidate_target_fps[0]
for candidate_fps in candidate_target_fps[1:]:
if num_frames / candidate_fps < duration:
break
target_fps = candidate_fps
times = np.arange(0, num_frames) / target_fps
times = times[times < duration]
return times
elif frame_sample_mode == "uniform_last_frame":
if max_fps is not None:
max_duration = (num_frames - 1) / max_fps # -1 to include the last frame
if max_duration < duration:
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
else:
times = np.arange(0.0, stop=duration, step=1 / max_fps)
times = np.concatenate([times, [duration]], axis=0)
assert len(times) <= num_frames
else:
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
return times
else:
raise NotImplementedError(frame_sample_mode)
def sample_frames(
self,
metadata: VideoMetadata,
frame_sample_mode: str | None = None,
num_frames: int | None = None,
max_fps: int | None = None,
sampling_fps: int | None = None,
**kwargs,
) -> np.ndarray:
"""
Frame-based sampling if an array video is passed
Args:
metadata (`VideoMetadata`):
Metadata of the video containing information about total duration, fps and total number of frames.
frame_sample_mode (`str`, *optional*):
Mode to sample frames. Defaults to `self.frame_sample_mode`.
num_frames (`int`, *optional*):
Maximum number of frames to sample. Defaults to `self.num_frames`.
max_fps (`int`, *optional*):
Maximum frames per second to sample.
sampling_fps (`int`, *optional*):
Sampling frames per second. Defaults to `self.sampling_fps`.
Used when `frame_sample_mode` is `"fps"`.
"""
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
num_frames = num_frames or self.num_frames
sampling_fps = sampling_fps or self.sampling_fps
total_num_frames = metadata.total_num_frames
if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
duration = total_num_frames / metadata.fps
if total_num_frames <= 2:
return np.arange(total_num_frames).astype(int)
if duration > (num_frames - 1) / max_fps: # -1 to include the last frame
# uniform fallback
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
return indices
else:
float_indices = np.arange(
0.0,
stop=total_num_frames - 1,
step=float(metadata.fps / max_fps),
)
if np.round(float_indices[-1]) != total_num_frames - 1:
float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
indices = np.round(float_indices).astype(int)
assert indices[-1] < total_num_frames
assert len(float_indices) <= num_frames
return indices
elif frame_sample_mode == "uniform_last_frame":
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
return indices
elif frame_sample_mode == "fps":
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
selected_target_fps = get_target_fps(
metadata.fps,
num_frames,
total_num_frames,
frame_sample_mode,
candidate_target_fps,
)
_, indices = get_frame_times_and_chosen_fps(
selected_target_fps,
total_num_frames,
num_frames,
metadata.fps,
)
return indices
else:
raise NotImplementedError(frame_sample_mode)
def fetch_videos(self, video_url_or_urls: str | list[str] | list[list[str]], sample_timestamps_fn=None):
"""
Convert a single URL or a list of URLs into the corresponding `np.array` objects.
If a single URL is passed, the return value will be a single object. If a list is passed, a list of objects is
returned.
"""
if (not is_decord_available()) and (not is_torchcodec_available()) and (not is_av_available()):
raise ImportError(
"MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
)
if is_decord_available():
backend = "decord"
elif is_torchcodec_available():
warnings.warn(
"`decord` is not installed and cannot be used to decode the video by default. "
"Falling back to `torchcodec`."
)
backend = "torchcodec"
else:
warnings.warn(
"`decord` is not installed and cannot be used to decode the video by default. "
"Falling back to `PyAV`."
)
backend = "pyav"
if isinstance(video_url_or_urls, list):
return list(
zip(
*[
self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn)
for x in video_url_or_urls
]
)
)
else:
return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)
def _decode_and_sample_videos(
self,
videos: VideoInput,
video_metadata: VideoMetadata | dict,
do_sample_frames: bool | None = None,
sample_indices_fn: Callable | None = None,
sample_timestamps_fn: Callable | None = None,
):
"""
Decode input videos and sample frames if needed.
"""
videos = make_batched_videos(videos)
video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
# Frame-based sampling if an array video is passed
# Otherwise, time-based sampling with decoding
if is_valid_video(videos[0]) and do_sample_frames:
assert video_metadata[0].fps is not None, "FPS must be provided for video input"
sampled_videos = []
sampled_metadata = []
for video, metadata in zip(videos, video_metadata):
indices = sample_indices_fn(metadata=metadata)
metadata.frames_indices = indices
sampled_videos.append(video[indices])
sampled_metadata.append(metadata)
videos = sampled_videos
video_metadata = sampled_metadata
elif not is_valid_video(videos[0]):
if sample_indices_fn is None:
logger.warning(
"do_sample_frames is False, but video array is not provided: "
"Will decode the video and sample frames using MolmoAct2's default sampling mode"
)
if isinstance(videos[0], list):
raise ValueError("A list of images is not supported for video input!")
else:
videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)
return videos, video_metadata
def _prepare_input_videos(
self,
videos: VideoInput,
**kwargs,
) -> list[np.ndarray]:
processed_videos = [to_numpy(video) for video in videos]
return processed_videos
def preprocess(
self,
videos: VideoInput,
**kwargs: Unpack[MolmoAct2VideoProcessorKwargs],
) -> BatchFeature:
validate_kwargs(
captured_kwargs=kwargs.keys(),
valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
do_sample_frames = kwargs.pop("do_sample_frames")
video_metadata = kwargs.pop("video_metadata")
sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
sample_timestamps_fn = partial(self.sample_times, **kwargs)
videos, video_metadata = self._decode_and_sample_videos(
videos,
video_metadata=video_metadata,
do_sample_frames=do_sample_frames,
sample_indices_fn=sample_indices_fn,
sample_timestamps_fn=sample_timestamps_fn,
)
videos = self._prepare_input_videos(videos=videos)
kwargs = self._further_process_kwargs(**kwargs)
return_metadata = kwargs.pop("return_metadata")
preprocessed_videos = self._preprocess(videos=videos, **kwargs)
if return_metadata:
preprocessed_videos["video_metadata"] = video_metadata
return preprocessed_videos
def _preprocess(
self,
videos: list[np.ndarray],
size: SizeDict | None = None,
resample: PILImageResampling | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool | None = None,
patch_size: int | None = None,
pooling_size: list[int] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
"""
Preprocess a video for the model.
Args:
videos (`VideoInput`):
Video to preprocess.
size (`SizeDict`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
The spatial patch size of the vision encoder.
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
The pooling size of the vision adapter.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
Returns:
A `BatchFeature` containing the following keys:
- `pixel_values_videos`: The preprocessed videos.
- `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`.
- `video_grids`: The video grids.
"""
if size.height is None or size.width is None:
raise ValueError("size must contain 'height' and 'width' keys.")
base_image_input_size = [size.height, size.width]
resample = resample or self.resample
image_mean = image_mean or self.image_mean
image_std = image_std or self.image_std
do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb
patch_size = patch_size or self.patch_size
pooling_size = pooling_size or self.pooling_size
image_pooling_h, image_pooling_w = pooling_size
batch_grids = []
batch_crops = []
batch_pooled_patches_idx = []
for video in videos:
all_crops = []
pooled_patches_idx = []
for frame in video:
image_grid, crops, pooled_idx = image_to_patches_and_grids(
frame,
base_image_input_size,
resample,
image_mean,
image_std,
patch_size,
image_pooling_w,
image_pooling_h,
)
offset = sum(np.prod(x.shape[:2]) for x in all_crops)
pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
pooled_patches_idx.append(pooled_idx_with_offset)
all_crops.append(crops)
video_grid = np.array([len(video), image_grid[0], image_grid[1]])
all_crops = np.concatenate(all_crops, 0)
pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)
batch_grids.append(video_grid)
batch_crops.append(all_crops)
batch_pooled_patches_idx.append(pooled_patches_idx)
video_grids = np.stack(batch_grids, 0)
pixel_values_videos = np.concatenate(batch_crops, 0)
video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
data = dict(
pixel_values_videos=pixel_values_videos,
video_token_pooling=video_token_pooling,
video_grids=video_grids,
)
return BatchFeature(data, tensor_type=return_tensors)
MolmoAct2VideoProcessor.register_for_auto_class()
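The `"uniform_last_frame"` branch of `sample_frames` above can be hard to follow from the flattened source, so here is a minimal pure-Python sketch of the same index selection. It is an illustration under stated assumptions, not the file's implementation: the function names `uniform_last_frame_indices` and `_linspace_int` are hypothetical, and `np.linspace`, `np.arange`, and `np.round` are replaced with plain-Python equivalents that are assumed to match for the small inputs shown.

```python
import math


def _linspace_int(last, n):
    # Hypothetical helper: n evenly spaced values from 0 to `last`,
    # truncated to int (mirrors np.linspace(...).astype(int)).
    if n == 1:
        return [0]
    step = last / (n - 1)
    values = [int(i * step) for i in range(n)]
    values[-1] = last  # the endpoint is always included exactly
    return values


def uniform_last_frame_indices(total_num_frames, fps, num_frames, max_fps=None):
    """Sketch of the "uniform_last_frame" index selection (assumed equivalent)."""
    last = total_num_frames - 1
    if max_fps is None:
        # No fps cap: spread at most num_frames indices over the whole video.
        return _linspace_int(last, min(num_frames, total_num_frames))
    if total_num_frames <= 2:
        return list(range(total_num_frames))
    duration = total_num_frames / fps
    if duration > (num_frames - 1) / max_fps:
        # Sampling at max_fps would exceed the frame budget: uniform fallback.
        return _linspace_int(last, min(num_frames, total_num_frames))
    # Otherwise step through the video at max_fps, then append the last frame.
    step = fps / max_fps
    count = math.ceil(last / step)
    float_indices = [i * step for i in range(count)]
    if round(float_indices[-1]) != last:
        float_indices.append(last)
    return [round(x) for x in float_indices]


# A 3 s clip at 30 fps with a budget of 8 frames and a cap of 2 fps.
short_clip = uniform_last_frame_indices(90, 30.0, 8, max_fps=2)
# A 10 s clip with a budget of 4 frames falls back to uniform spacing.
long_clip = uniform_last_frame_indices(300, 30.0, 4, max_fps=2)
```

In both cases the final index is the last frame of the video, which is the point of this sampling mode: short clips are sampled at `max_fps` and long clips are spread uniformly across the frame budget.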

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -15,6 +15,7 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -29,7 +30,6 @@ from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
@@ -41,7 +41,6 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
DynamicCache = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
@@ -142,15 +141,6 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
return att_2d_masks & pad_2d_masks
def clone_past_key_values(past_key_values):
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
return DynamicCache(
tuple(
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
)
)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -237,13 +227,16 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = layers[i]
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
@@ -265,16 +258,15 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = rotary_emb(dummy_tensor, position_ids)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
paligemma_layer = layers[0]
scaling = paligemma_layer.self_attn.scaling
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma_layer.self_attn,
paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
@@ -282,13 +274,13 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma_layer.self_attn.head_dim
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = layers[i]
layer = models[i].layers[layer_idx]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
@@ -496,9 +488,8 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
paligemma_layers = self.paligemma.model.language_model.layers
gemma_expert_layers = self.gemma_expert.model.layers
rotary_emb = self.paligemma.model.language_model.rotary_emb
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
@@ -508,39 +499,36 @@ class PaliGemmaWithExpertModel(
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
layers=layers,
rotary_emb=rotary_emb,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
layers=layers,
rotary_emb=rotary_emb,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
# final norm
final_norms = (
self.paligemma.model.language_model.norm,
self.gemma_expert.model.norm,
)
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -919,7 +907,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = clone_past_key_values(past_key_values)
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
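The last hunk above swaps `clone_past_key_values` for `copy.deepcopy` before the denoising forward pass. The sketch below illustrates why some copy is taken at that point, assuming the forward pass appends to the cache in place. `ToyCache` and `denoise_step` are hypothetical stand-ins written for this example; they are not the `transformers` `DynamicCache` API or the policy's code.

```python
import copy


class ToyCache:
    """Hypothetical stand-in for a KV cache: one key/value list per layer."""

    def __init__(self, num_layers):
        self.layers = [{"keys": [], "values": []} for _ in range(num_layers)]

    def update(self, layer_idx, key, value):
        # Appends in place, as an attention layer would during a forward pass.
        self.layers[layer_idx]["keys"].append(key)
        self.layers[layer_idx]["values"].append(value)

    def seq_len(self):
        return len(self.layers[0]["keys"])


def denoise_step(prefix_cache, suffix_tokens):
    # Work on a deep copy so the prefix cache stays reusable across steps.
    work = copy.deepcopy(prefix_cache)
    for layer_idx in range(len(work.layers)):
        for tok in suffix_tokens:
            work.update(layer_idx, tok, tok)
    return work.seq_len()


# Prefill the prefix (image + language tokens) once.
prefix = ToyCache(num_layers=2)
for layer_idx in range(2):
    for tok in ("img", "lang"):
        prefix.update(layer_idx, tok, tok)

# Run several denoising steps against the same prefix cache.
lengths = [denoise_step(prefix, ["a0", "a1", "a2"]) for _ in range(3)]
```

Every step sees the same two prefix entries plus its own three suffix entries, and the prefix cache itself is unchanged afterwards. Without the copy, each step would keep extending the shared cache.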


@@ -15,6 +15,7 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -29,7 +30,6 @@ from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
@@ -41,7 +41,6 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
DynamicCache = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
@@ -139,15 +138,6 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
return att_2d_masks & pad_2d_masks
def clone_past_key_values(past_key_values):
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
return DynamicCache(
tuple(
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
)
)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -234,13 +224,16 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = layers[i]
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
@@ -262,16 +255,15 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = rotary_emb(dummy_tensor, position_ids)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
paligemma_layer = layers[0]
scaling = paligemma_layer.self_attn.scaling
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma_layer.self_attn,
paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
@@ -279,13 +271,13 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma_layer.self_attn.head_dim
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = layers[i]
layer = models[i].layers[layer_idx]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
@@ -493,9 +485,8 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
paligemma_layers = self.paligemma.model.language_model.layers
gemma_expert_layers = self.gemma_expert.model.layers
rotary_emb = self.paligemma.model.language_model.rotary_emb
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
@@ -505,39 +496,36 @@ class PaliGemmaWithExpertModel(
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
layers=layers,
rotary_emb=rotary_emb,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
layers=layers,
rotary_emb=rotary_emb,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
# final norm
final_norms = (
self.paligemma.model.language_model.norm,
self.gemma_expert.model.norm,
)
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -892,7 +880,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = clone_past_key_values(past_key_values)
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,


@@ -1 +0,0 @@
/home/maxime/github/robots/lerobot/docs/source/policy_vla_jepa_README.md


@@ -1,10 +0,0 @@
from .configuration_vla_jepa import VLAJEPAConfig
from .modeling_vla_jepa import VLAJEPAPolicy
from .processor_vla_jepa import VLAJEPANewLineProcessor, make_vla_jepa_pre_post_processors
__all__ = [
"VLAJEPAConfig",
"VLAJEPAPolicy",
"VLAJEPANewLineProcessor",
"make_vla_jepa_pre_post_processors",
]


@@ -1,327 +0,0 @@
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _diffusers_available, require_package
if TYPE_CHECKING or _diffusers_available:
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
else:
class ModelMixin: # type: ignore[no-redef]
pass
class ConfigMixin: # type: ignore[no-redef]
pass
register_to_config = lambda f: f # noqa: E731
Attention = FeedForward = TimestepEmbedding = Timesteps = None
from .configuration_vla_jepa import VLAJEPAConfig
def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
batch_size, seq_len = timesteps.shape
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
freqs = timesteps.unsqueeze(-1) * exponent.exp()
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
class ActionEncoder(nn.Module):
def __init__(self, action_dim: int, hidden_size: int):
super().__init__()
self.layer1 = nn.Linear(action_dim, hidden_size)
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
self.layer3 = nn.Linear(hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = actions.shape
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
raise ValueError("timesteps must have shape [batch_size].")
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
action_emb = self.layer1(actions)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
return self.layer3(swish(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
require_package("diffusers", extra="vla_jepa")
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
return self.timestep_embedder(projected)
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
self.silu = nn.SiLU()
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float,
cross_attention_dim: int,
is_cross_attention: bool = True,
) -> None:
super().__init__()
self.is_cross_attention = is_cross_attention
self.norm1 = AdaLayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=True,
cross_attention_dim=cross_attention_dim,
out_bias=True,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None,
temb: torch.Tensor,
) -> torch.Tensor:
attn_input = self.norm1(hidden_states, temb)
attention_context = encoder_hidden_states if self.is_cross_attention else None
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
return hidden_states


class DiT(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        output_dim: int,
        num_layers: int,
        dropout: float,
        cross_attention_dim: int,
    ) -> None:
        super().__init__()
        self.inner_dim = num_attention_heads * attention_head_dim
        self.timestep_encoder = TimestepEncoder(self.inner_dim)
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
                    is_cross_attention=layer_idx % 2 == 0,
                )
                for layer_idx in range(num_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
        self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
        self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.Tensor,
    ) -> torch.Tensor:
        temb = self.timestep_encoder(timestep)
        x = hidden_states
        for block in self.transformer_blocks:
            x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
        shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
        x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
        return self.proj_out_2(x)


@dataclass
class ActionModelPreset:
    hidden_size: int
    attention_head_dim: int
    num_attention_heads: int


DIT_PRESETS = {
    "DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
    "DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
    "DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
}
class VLAJEPAActionHead(nn.Module):
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
super().__init__()
preset = DIT_PRESETS[config.action_model_type]
self.config = config
num_heads = config.action_num_heads or preset.num_attention_heads
head_dim = config.action_attention_head_dim or preset.attention_head_dim
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
self.input_embedding_dim = inner_dim
self.action_horizon = config.chunk_size
self.num_inference_timesteps = config.num_inference_timesteps
hidden_size = config.action_hidden_size
self.model = DiT(
num_attention_heads=num_heads,
attention_head_dim=head_dim,
output_dim=hidden_size,
num_layers=config.action_num_layers,
dropout=config.action_dropout,
cross_attention_dim=cross_attention_dim,
)
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
self.action_decoder = nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(hidden_size, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, config.action_dim)),
]
)
)
self.state_encoder = (
nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(config.state_dim, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, inner_dim)),
]
)
)
if config.state_dim > 0
else None
)
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
self.position_embedding = nn.Embedding(
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
inner_dim,
)
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
return (self.config.action_noise_s - sample) / self.config.action_noise_s
def _build_inputs(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None,
timesteps: torch.Tensor,
) -> torch.Tensor:
action_features = self.action_encoder(actions, timesteps)
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
action_features = action_features + self.position_embedding(pos_ids)[None]
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
seq = [future_tokens, action_features]
if state is not None and self.state_encoder is not None:
if state.ndim == 2:
state = state.unsqueeze(1)
seq.insert(0, self.state_encoder(state))
return torch.cat(seq, dim=1)
def forward(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None = None,
action_is_pad: torch.Tensor | None = None,
) -> torch.Tensor:
noise = torch.randn_like(actions)
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
velocity = actions - noise
t_discretized = (t * self.config.action_num_timestep_buckets).long()
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=t_discretized,
)
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
if action_is_pad is None:
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
num_valid = valid_mask.sum() * loss.shape[-1]
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
@torch.no_grad()
def predict_action(
self,
conditioning_tokens: torch.Tensor,
state: torch.Tensor | None = None,
) -> torch.Tensor:
batch_size = conditioning_tokens.shape[0]
actions = torch.randn(
batch_size,
self.action_horizon,
self.config.action_dim,
dtype=conditioning_tokens.dtype,
device=conditioning_tokens.device,
)
dt = 1.0 / max(self.num_inference_timesteps, 1)
for step in range(self.num_inference_timesteps):
t_cont = step / float(max(self.num_inference_timesteps, 1))
t_value = int(t_cont * self.config.action_num_timestep_buckets)
timesteps = torch.full(
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
)
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=timesteps,
)
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
actions = actions + dt * pred_velocity
return actions
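
The sampling loop in `predict_action` is a fixed-step Euler integration of the predicted velocity, and `forward` regresses against the straight-line target `actions - noise`. The following is a minimal, dependency-free sketch of that convention, using plain Python floats instead of tensors. The names `interpolate`, `velocity_target`, `euler_sample`, and `oracle_velocity` are hypothetical and exist only for illustration; the oracle stands in for the trained DiT and is an assumption, not part of the module above.

```python
import random


def interpolate(noise, action, t):
    """Point on the straight path from noise (t=0) to action (t=1)."""
    return [(1.0 - t) * n + t * a for n, a in zip(noise, action)]


def velocity_target(noise, action):
    """Flow-matching regression target: the time derivative of interpolate()."""
    return [a - n for n, a in zip(noise, action)]


def euler_sample(velocity_fn, noise, num_steps):
    """Integrate dx/dt = velocity_fn(x, t) from t=0 to t=1 with fixed steps."""
    num_steps = max(num_steps, 1)
    dt = 1.0 / num_steps
    x = list(noise)
    for step in range(num_steps):
        t = step / num_steps
        v = velocity_fn(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


rng = random.Random(0)
action = [0.5, -0.25, 1.0]
noise = [rng.gauss(0.0, 1.0) for _ in action]


def oracle_velocity(x, t):
    # Hypothetical stand-in for the trained model: it always returns the exact
    # straight-line velocity, so the integration ends on the target action.
    return velocity_target(noise, action)


sample = euler_sample(oracle_velocity, noise, num_steps=4)
print([round(s, 6) for s in sample])
```

With a learned, imperfect velocity field the integration no longer lands exactly on the target, and the step count (`num_inference_timesteps` in the config) typically trades accuracy against latency.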


@@ -1,138 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("vla_jepa")
@dataclass
class VLAJEPAConfig(PreTrainedConfig):
n_obs_steps: int = 1
chunk_size: int = 7
n_action_steps: int = 7
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MIN_MAX,
}
)
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
freeze_qwen: bool = False
enable_world_model: bool = True
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
# different action or state dimensionality, the input/output projection layers must be
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
reinit_modules: list[str] | None = None
tokenizer_padding_side: str = "left"
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
special_action_token: str = "<|action_{}|>"
embodied_action_token: str = "<|embodied_action|>"
action_dim: int = 7
state_dim: int = 8
num_action_tokens_per_timestep: int = 8
num_embodied_action_tokens_per_instruction: int = 32
num_inference_timesteps: int = 4
action_hidden_size: int = 1024
action_model_type: str = "DiT-B"
action_num_layers: int = 16
action_num_heads: int | None = None
action_attention_head_dim: int | None = None
action_dropout: float = 0.2
action_num_timestep_buckets: int = 1000
action_noise_beta_alpha: float = 1.5
action_noise_beta_beta: float = 1.0
action_noise_s: float = 0.999
num_target_vision_tokens: int = 32
action_max_seq_len: int = 1024
# total video frames loaded per sample
num_video_frames: int = 8
predictor_depth: int = 12
predictor_num_heads: int = 8
predictor_mlp_ratio: float = 4.0
predictor_dropout: float = 0.0
world_model_loss_weight: float = 0.1
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
resize_images_to: tuple[int, int] | None = None
binarize_gripper_action: bool = True
pre_snap_gripper_action: bool = True
clip_normalized_actions: bool = True
torch_dtype: str = "bfloat16"
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10.0
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self) -> None:
super().__post_init__()
if self.freeze_qwen and self.enable_world_model:
# Freezing the Qwen backbone makes world-model training irrelevant (no gradient reaches the backbone), so disable it.
self.enable_world_model = False
if self.n_action_steps > self.chunk_size:
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
if self.num_video_frames < 2 * self.jepa_tubelet_size:
raise ValueError(
f"`num_video_frames` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
f"({self.jepa_tubelet_size}) to have at least one context and one ground-truth temporal position."
)
def validate_features(self) -> None:
if not self.image_features:
raise ValueError("VLAJEPA requires at least one visual input feature.")
if self.action_feature is None:
raise ValueError("VLAJEPA requires an action output feature.")
self.action_dim = self.action_feature.shape[0]
if self.robot_state_feature is not None:
self.state_dim = self.robot_state_feature.shape[0]
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int]:
# Load `num_video_frames` frames starting from the current timestep: [t, t+1, ..., t+num_video_frames-1].
# Matches the original repo's observation_indices=list(range(video_horizon)).
return list(range(self.num_video_frames))
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
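
To make the relationship between `num_video_frames`, `jepa_tubelet_size`, and `num_action_tokens_per_timestep` concrete, here is a small sketch of the arithmetic that the `__post_init__` check protects. `FrameBudget` and `frame_budget` are hypothetical helpers, not part of the library. The rule that there is one fewer context position than temporal positions is taken from how the model code in this diff splits the encoder output, so treat it as an illustration under that assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FrameBudget:
    temporal_positions: int  # encoder positions after tubelet compression
    context_positions: int  # positions fed to the predictor (targets are shifted by one)
    action_tokens: int  # action tokens the predictor expects


def frame_budget(
    num_video_frames: int = 8,
    tubelet_size: int = 2,
    tokens_per_step: int = 8,
) -> FrameBudget:
    # Same guard as the config: at least one context and one target position.
    if num_video_frames < 2 * tubelet_size:
        raise ValueError(
            f"num_video_frames ({num_video_frames}) must be >= 2 * tubelet_size ({tubelet_size})"
        )
    positions = num_video_frames // tubelet_size
    context = positions - 1
    return FrameBudget(positions, context, context * tokens_per_step)


print(frame_budget())
```

The defaults used here mirror the config defaults (`num_video_frames=8`, `jepa_tubelet_size=2`, `num_action_tokens_per_timestep=8`).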


@@ -1,615 +0,0 @@
from __future__ import annotations
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModel, AutoVideoProcessor
else:
AutoModel = None
AutoVideoProcessor = None
from .action_head import VLAJEPAActionHead
from .configuration_vla_jepa import VLAJEPAConfig
from .qwen_interface import Qwen3VLInterface
from .world_model import ActionConditionedVideoPredictor
# ============================================================================
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
# ============================================================================
class VLAJEPAModel(nn.Module):
"""
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
Components:
- Qwen3-VL: vision-language backbone for fused embeddings
- DiT-B: flow-matching action head for future action prediction
- V-JEPA: world model for video frame prediction
Input: List[dict] native format (same as original starVLA)
- "image": List[PIL.Image] (multi-view images)
- "video": np.ndarray [V, T, H, W, 3]
- "lang": str (task instruction)
- "action": np.ndarray [T, action_dim] (optional, training only)
- "state": np.ndarray [1, state_dim] (optional)
"""
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
require_package("transformers", extra="vla_jepa")
self.config = config
# Vision-language backbone
self.qwen = Qwen3VLInterface(config)
# Tokenizer expansion for special action tokens
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
self.qwen.expand_tokenizer()
)
# Action head (flow-matching DiT)
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
# JEPA world model components
if config.enable_world_model:
self.video_encoder = AutoModel.from_pretrained(
config.jepa_encoder_name,
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
)
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
num_views = config.jepa_tubelet_size
tubelet_size = self.video_encoder.config.tubelet_size
image_size = getattr(self.video_encoder.config, "image_size", None)
if image_size is None:
first_image_shape = next(iter(config.image_features.values())).shape
image_size = first_image_shape[-1]
self.video_predictor = ActionConditionedVideoPredictor(
num_frames=config.num_video_frames // tubelet_size,
img_size=(image_size, image_size),
patch_size=16,
tubelet_size=1,
embed_dim=self.video_encoder.config.hidden_size * num_views,
action_embed_dim=self.qwen.model.config.hidden_size,
predictor_embed_dim=self.video_encoder.config.hidden_size,
depth=config.predictor_depth,
num_heads=config.predictor_num_heads,
mlp_ratio=config.predictor_mlp_ratio,
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
)
else:
self.video_encoder = None
self.video_processor = None
self.video_predictor = None
if config.freeze_qwen:
self.qwen.requires_grad_(False)
# Build prompt placeholders.
# Use the encoder's actual tubelet_size when available (world model enabled),
# otherwise fall back to config.
_tubelet_size = (
self.video_encoder.config.tubelet_size
if config.enable_world_model
else self.config.jepa_tubelet_size
)
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
self.replace_prompt = "".join(
token * self.config.num_action_tokens_per_timestep
for token in self.action_tokens[:num_action_prompt_steps]
)
self.embodied_replace_prompt = (
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
)
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return the last decoder hidden state before the final RMSNorm.
The model was trained with the output of the last transformer block BEFORE
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
the correct pre-RMSNorm state, matching the training-time representation.
"""
captured: list[torch.Tensor] = []
def _hook(_module, _inputs, output):
h = output[0] if isinstance(output, tuple) else output
captured.append(h)
last_layer = self.qwen.model.model.language_model.layers[-1]
handle = last_layer.register_forward_hook(_hook)
try:
self.qwen.model(
**qwen_inputs,
output_hidden_states=False,
output_attentions=False,
return_dict=True,
)
finally:
handle.remove()
return captured[0] # [B, seq_len, H]
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
"""
Native forward pass following original starVLA VLA_JEPA.forward.
Args:
examples: List of per-sample dicts with keys:
"image" : List[PIL.Image] — multi-view images
"video" : np.ndarray [V, T, H, W, 3]
"lang" : str — task instruction
"action" : np.ndarray [T, action_dim] (optional)
"state" : np.ndarray [1, state_dim] (optional)
Returns:
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
"""
# Unpack native format (same pattern as original VLA_JEPA.py)
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
instructions = [ex["lang"] for ex in examples] # List[str]
has_action = "action" in examples[0] and examples[0]["action"] is not None
actions = [ex["action"] for ex in examples] if has_action else None
has_state = "state" in examples[0] and examples[0]["state"] is not None
state = [ex["state"] for ex in examples] if has_state else None
action_is_pad = (
[ex["action_is_pad"] for ex in examples]
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
else None
)
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
batch_videos = np.stack(batch_videos)
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
# Adjust number of views for the world model:
# - fewer views than expected: duplicate the first view to fill up
# - more views than expected: keep only the first num_views_world_model views
num_views_world_model = self.config.jepa_tubelet_size
if batch_videos.shape[1] < num_views_world_model:
num_missing_views = num_views_world_model - batch_videos.shape[1]
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
elif batch_videos.shape[1] > num_views_world_model:
batch_videos = batch_videos[:, :num_views_world_model]
# ---- Step 1: QwenVL encode (same as original) ----
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
# Locate embodied-action tokens (always needed for action head)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
# Locate action tokens (only needed for world model predictor)
if self.config.enable_world_model:
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
)
action_indices = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
if self.config.enable_world_model:
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
# ---- Step 2+3: JEPA Encoder + Predictor ----
device_wm = last_hidden.device
if not self.config.enable_world_model:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device) # [B*V, T, C, H, W]
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
tubelet_size = self.video_encoder.config.tubelet_size
device_wm = video_embeddings.device
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
t_enc_total = self.config.num_video_frames // tubelet_size
if t_enc_total < 2:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
# input_states: positions 0..T-2, gt_states: positions 1..T-1
t_enc_ctx = t_enc_total - 1
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
predicted_states = self.video_predictor(
input_states.float(),
action_tokens[:, :expected_actions].float(),
)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
if not has_action:
return {"wm_loss": wm_loss}
# ---- Step 4: Action Head ----
with torch.autocast(device_type=device_type, dtype=torch.float32):
actions_tensor = torch.tensor(
np.array(actions), device=last_hidden.device, dtype=torch.float32
) # [B, T_full, action_dim]
action_horizon = self.config.chunk_size
actions_target = actions_tensor[:, -action_horizon:, :]
state_tensor = None
if state is not None:
state_tensor = torch.tensor(
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
) # [B, 1, state_dim]
repeated_diffusion_steps = self.config.repeated_diffusion_steps
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
if state_tensor is not None:
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
action_is_pad_rep = None
if action_is_pad is not None:
pad_tensor = torch.stack(
[
p.to(actions_target.device)
if isinstance(p, Tensor)
else torch.tensor(p, device=actions_target.device)
for p in action_is_pad
]
) # [B, T_full]
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
action_loss = self.action_model(
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
)
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
@torch.no_grad()
def predict_action(
self,
batch_images: list[list[Image.Image]],
instructions: list[str],
state: np.ndarray | None = None,
) -> np.ndarray:
"""
Native action prediction following original VLA_JEPA.predict_action.
Args:
batch_images: List of samples; each is List[PIL.Image] (multi-view).
instructions: Task instructions, one per sample.
state: Optional [B, state_dim] numpy array.
Returns:
np.ndarray [B, action_horizon, action_dim] — predicted actions.
"""
if self.config.resize_images_to is not None:
height, width = self.config.resize_images_to
resampling = getattr(Image, "Resampling", Image).BOX
batch_images = [
[image.resize((width, height), resample=resampling) for image in sample_images]
for sample_images in batch_images
]
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
state_tensor = None
if state is not None:
state_tensor = torch.from_numpy(np.array(state)).to(
device=last_hidden.device, dtype=last_hidden.dtype
)
pred_actions = self.action_model.predict_action(
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
) # [B, action_horizon, action_dim]
return pred_actions.detach().cpu().numpy()
# ============================================================================
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
# ============================================================================
class VLAJEPAPolicy(PreTrainedPolicy):
"""
LeRobot adapter for VLA-JEPA.
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
back to LeRobot format.
"""
config_class = VLAJEPAConfig
name = "vla_jepa"
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
super().__init__(config)
config.validate_features()
if dataset_meta := kwargs.get("dataset_meta"):
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
# compatibility), so validate_features() may have read stale dims from a pretrained
# config. Override state_dim/action_dim from the actual dataset being used.
ds_features = dataset_meta.features
if OBS_STATE in ds_features:
config.state_dim = ds_features[OBS_STATE]["shape"][0]
if ACTION in ds_features:
config.action_dim = ds_features[ACTION]["shape"][0]
self.model = VLAJEPAModel(config)
self.reset()
def reset(self) -> None:
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
# ---- Format Conversion: LeRobot → Native ----
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
"""
Convert LeRobot batch format to native VLA-JEPA examples format.
LeRobot format:
batch = {
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
"action": Tensor [B, chunk_size, action_dim], (training only)
"task": str | List[str], (optional instruction)
}
Native format (List[dict]):
{
"image": List[PIL.Image], # multi-view images per sample
"video": np.ndarray [V, T, H, W, 3],
"lang": str, # task instruction
"action": np.ndarray [T, action_dim], # optional
"state": np.ndarray [1, state_dim], # optional
}
"""
# Determine batch size from the first image feature
image_keys = list(self.config.image_features.keys())
if not image_keys:
raise ValueError("VLAJEPA requires at least one image feature.")
first_key = image_keys[0]
first_tensor = batch[first_key]
batch_size = first_tensor.shape[0]
# ---- Collect images per sample ----
# images_per_sample[b][v] = PIL.Image for view v
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
for key in image_keys:
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
if tensor.ndim == 5:
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
# index 0 is the current observation (delta=0)
tensor = tensor[:, 0]
for b in range(batch_size):
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
# ---- Collect videos per sample ----
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
# Check whether any image feature has a time dimension
video_source = None
for k in image_keys:
if k in batch:
video_source = batch[k] # Use first available for shape inspection
break
if video_source is None:
raise ValueError("No image data found in batch for video construction.")
videos_per_sample = []
for b in range(batch_size):
sample_views = []
for k in image_keys:
t = batch[k][b] # [C, H, W] or [T, C, H, W]
if t.ndim == 3:
t = t.unsqueeze(0) # [1, C, H, W]
# Convert to [T, H, W, 3] numpy
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
# Heuristic: treat arrays with max <= 1.0 as [0, 1] floats and rescale to [0, 255], then round and clamp to uint8
if t_np.max() <= 1.0:
t_np = t_np * 255.0
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
sample_views.append(t_np)
# Stack views: [V, T, H, W, 3]
videos_per_sample.append(np.stack(sample_views, axis=0))
# ---- Collect instructions ----
tasks = batch.get("task")
if tasks is None:
instructions = ["Execute the robot action."] * batch_size
elif isinstance(tasks, str):
instructions = [tasks] * batch_size
else:
instructions = list(tasks)
# ---- Collect actions (training only) ----
actions_list = None
action_is_pad_list = None
actions_tensor = batch.get(ACTION)
if actions_tensor is not None:
if actions_tensor.ndim == 2:
actions_tensor = actions_tensor.unsqueeze(1)
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
action_is_pad_tensor = batch.get("action_is_pad")
if action_is_pad_tensor is not None:
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
# ---- Collect state ----
state_list = None
state_tensor = batch.get(OBS_STATE)
if state_tensor is not None:
if state_tensor.ndim > 2:
state_tensor = state_tensor[:, -1, :]
if state_tensor.ndim == 2:
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
# ---- Assemble native examples ----
examples = []
for b in range(batch_size):
example = {
"image": images_per_sample[b],
"video": videos_per_sample[b],
"lang": instructions[b],
}
if actions_list is not None:
example["action"] = actions_list[b]
if action_is_pad_list is not None:
example["action_is_pad"] = action_is_pad_list[b]
if state_list is not None:
example["state"] = state_list[b]
examples.append(example)
return examples
# ---- LeRobot Policy Interface ----
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""LeRobot train forward: convert → native forward → aggregate losses."""
examples = self._prepare_model_inputs(batch)
native_output = self.model.forward(examples)
ref = next(iter(native_output.values()))
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
logs = {k: v.detach().item() for k, v in native_output.items()}
logs["loss"] = total_loss.detach().item()
return total_loss, logs
def get_optim_params(self) -> dict:
return self.model.parameters()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot inference: convert → native predict → return as Tensor."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
examples = self._prepare_model_inputs(batch)
batch_images = [ex["image"] for ex in examples]
instructions = [ex["lang"] for ex in examples]
state_np = None
if "state" in examples[0] and examples[0]["state"] is not None:
state_np = np.stack([ex["state"] for ex in examples])
actions_np = self.model.predict_action(batch_images, instructions, state_np)
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot select_action with action queue caching."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@classmethod
def from_pretrained(
cls: type[T],
pretrained_name_or_path: str | Path,
**kwargs,
):
return super().from_pretrained(pretrained_name_or_path, **kwargs)
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
reinit_prefixes = model.config.reinit_modules
if not reinit_prefixes:
return super()._load_as_safetensor(model, model_file, map_location, strict)
from safetensors.torch import load_file
state_dict = load_file(model_file, device=map_location)
current = model.state_dict()
reinitialized: list[str] = []
filtered: dict = {}
for key, value in state_dict.items():
if key in current and value.shape != current[key].shape:
if not any(key.startswith(p) for p in reinit_prefixes):
raise ValueError(
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
)
reinitialized.append(
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
)
else:
filtered[key] = value
if reinitialized:
logging.warning(
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
)
from lerobot.policies.utils import log_model_loading_keys
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
log_model_loading_keys(missing_keys, unexpected_keys)
return model
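
Reviewer sketch (not part of the diff): the `reinit_modules` filter in `_load_as_safetensor` can be illustrated without torch by treating a state dict as a mapping from parameter name to shape tuple. The helper name `filter_by_shape` and the example keys are hypothetical; only the control flow is meant to mirror the loader above.

```python
def filter_by_shape(checkpoint, current, reinit_prefixes):
    """Split checkpoint entries into loadable ones and ones left to re-initialise.

    `checkpoint` and `current` map parameter names to shape tuples, a stand-in
    for real tensors. Hypothetical helper, for illustration only.
    """
    kept, skipped = {}, []
    for key, shape in checkpoint.items():
        if key in current and shape != current[key]:
            # A mismatch is only tolerated under an explicitly listed prefix.
            if not any(key.startswith(p) for p in reinit_prefixes):
                raise ValueError(f"Shape mismatch for '{key}' outside reinit_modules")
            skipped.append(key)
        else:
            kept[key] = shape
    return kept, skipped


checkpoint = {"backbone.w": (4, 4), "action_head.w": (7, 4)}
current = {"backbone.w": (4, 4), "action_head.w": (14, 4)}
kept, skipped = filter_by_shape(checkpoint, current, ["action_head."])
print(kept)     # {'backbone.w': (4, 4)}
print(skipped)  # ['action_head.w']
```

Keys that are absent from the model fall through to `kept`, which matches the loader: they reach `load_state_dict(..., strict=False)` and are reported as unexpected keys rather than raising.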

@@ -1,139 +0,0 @@
from __future__ import annotations

from typing import Any

import torch

from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    ComplementaryDataProcessorStep,
    DeviceProcessorStep,
    EnvTransition,
    NormalizerProcessorStep,
    PolicyAction,
    PolicyProcessorPipeline,
    ProcessorStep,
    ProcessorStepRegistry,
    RenameObservationsProcessorStep,
    TransitionKey,
    UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME


@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
class ClipActionsProcessorStep(ProcessorStep):
    """Clips action tensor to [-1, 1] before unnormalization."""

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)
        if action is not None:
            transition = dict(transition)
            transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
        return transition

    def transform_features(self, features):
        return features


@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
class PreSnapGripperProcessorStep(ProcessorStep):
    """Snaps gripper dim (index 6) to {0, 1} BEFORE unnormalization.

    Mirrors the original starVLA LIBERO eval:
        normalized[:, 6] = np.where(normalized[:, 6] < 0.5, 0, 1)

    This ensures the unnormalizer receives an exact binary value, which is
    required when the model was trained with gripper in identity (mask=False)
    space where 0=open and 1=close.
    """

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)
        if action is not None and action.shape[-1] >= 7:
            transition = dict(transition)
            a = action.clone()
            a[..., 6] = (a[..., 6] >= 0.5).float()
            transition[TransitionKey.ACTION] = a
        return transition

    def transform_features(self, features):
        return features


@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
class BinarizeGripperProcessorStep(ProcessorStep):
    """Binarizes gripper dim (index 6) after unnormalization.

    Maps continuous value to {-1, 1}: > 0.5 → -1, <= 0.5 → 1 (matches starVLA convention).
    Only applied when action has >= 7 dimensions.
    """

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)
        if action is not None and action.shape[-1] >= 7:
            transition = dict(transition)
            a = action.clone()
            a[..., 6] = 1.0 - 2.0 * (a[..., 6] > 0.5).float()
            transition[TransitionKey.ACTION] = a
        return transition

    def transform_features(self, features):
        return features


def make_vla_jepa_pre_post_processors(
    config: VLAJEPAConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    features = {**config.input_features, **config.output_features}
    input_steps = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device=config.device),
        NormalizerProcessorStep(
            features=features,
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
    ]
    output_steps: list[ProcessorStep] = []
    if config.clip_normalized_actions:
        output_steps.append(ClipActionsProcessorStep())
    if config.pre_snap_gripper_action:
        output_steps.append(PreSnapGripperProcessorStep())
    output_steps.append(
        UnnormalizerProcessorStep(
            features=features,
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        )
    )
    if config.binarize_gripper_action:
        output_steps.append(BinarizeGripperProcessorStep())
    output_steps.append(DeviceProcessorStep(device="cpu"))
    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=input_steps,
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=output_steps,
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )


@ProcessorStepRegistry.register(name="vla_jepa_new_line_processor")
class VLAJEPANewLineProcessor(ComplementaryDataProcessorStep):
    def complementary_data(self, complementary_data):
        return complementary_data

    def transform_features(self, features):
        return features
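
Reviewer sketch (not part of the diff): the two gripper steps use different thresholds and output ranges, which is easy to misread. The scalar functions below are hypothetical stand-ins for the tensor code in `PreSnapGripperProcessorStep` and `BinarizeGripperProcessorStep`; they only reproduce the comparisons.

```python
def pre_snap(normalized_gripper: float) -> float:
    """Snap a normalized gripper value to {0, 1}: values >= 0.5 become 1."""
    return 1.0 if normalized_gripper >= 0.5 else 0.0


def binarize(gripper: float) -> float:
    """Map an unnormalized gripper value to {-1, 1}: values > 0.5 become -1."""
    return -1.0 if gripper > 0.5 else 1.0


for value in (0.2, 0.5, 0.9):
    print(value, pre_snap(value), binarize(value))
# 0.2 0.0 1.0
# 0.5 1.0 1.0
# 0.9 1.0 -1.0
```

If both steps are enabled and the gripper dimension passes through the unnormalizer unchanged (an assumption that depends on `normalization_mapping` and the dataset stats), the chain maps 0 (open) to +1 and 1 (close) to -1. The boundary value 0.5 is the only input the two comparisons treat differently.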

@@ -1,103 +0,0 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
import numpy as np
import torch
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
else:
AutoProcessor = None
Qwen3VLForConditionalGeneration = None
from .configuration_vla_jepa import VLAJEPAConfig
class Qwen3VLInterface(torch.nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
config.qwen_model_name,
torch_dtype=self._get_torch_dtype(config.torch_dtype),
)
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
self.model.config.hidden_size = self.model.config.text_config.hidden_size
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
if dtype_name == "float32":
return torch.float32
if dtype_name == "float16":
return torch.float16
return torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
# is required for Qwen embedding/lm_head checkpoint shapes to match.
max_action_tokens = self.config.chunk_size * 4
tokenizer = self.processor.tokenizer
action_tokens = []
action_token_ids = []
for idx in range(max_action_tokens):
token = self.config.special_action_token.format(idx)
action_tokens.append(token)
if token not in tokenizer.get_vocab():
tokenizer.add_tokens([token], special_tokens=True)
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
embodied_action_token = self.config.embodied_action_token
if embodied_action_token not in tokenizer.get_vocab():
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
self.model.resize_token_embeddings(len(tokenizer))
return action_tokens, action_token_ids, embodied_action_token_id
def build_inputs(
self,
images: Sequence[Sequence[Image.Image]],
instructions: Sequence[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, torch.Tensor]:
messages = []
for sample_images, instruction in zip(images, instructions, strict=True):
prompt = self.config.prompt_template.format(
instruction=instruction,
actions=action_prompt,
e_actions=embodied_prompt,
)
content = [{"type": "image", "image": img} for img in sample_images]
content.append({"type": "text", "text": prompt})
messages.append([{"role": "user", "content": content}])
batch_inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
processor_kwargs={"padding": True, "return_tensors": "pt"},
)
return batch_inputs.to(self.model.device)
@staticmethod
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = image.float()
if image.max() <= 1.0:
image = image * 255.0
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
if image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
return Image.fromarray(image)
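
Reviewer sketch (not part of the diff): `expand_tokenizer` adds `chunk_size * 4` action tokens plus one embodied-action token, and only adds tokens that are missing. The sketch below replays that bookkeeping on a plain dict. The helper `expand_vocab` and the token strings `<|action_{}|>` and `<|embodied_action|>` are assumptions for illustration; the real values come from `VLAJEPAConfig`.

```python
def expand_vocab(vocab, chunk_size, template, embodied_token):
    """Add chunk_size * 4 action tokens plus one embodied token to `vocab`.

    `vocab` is a plain dict standing in for a tokenizer vocabulary; new tokens
    receive the next free id. Hypothetical helper, for illustration only.
    """
    action_ids = []
    for idx in range(chunk_size * 4):
        token = template.format(idx)
        if token not in vocab:
            vocab[token] = len(vocab)
        action_ids.append(vocab[token])
    if embodied_token not in vocab:
        vocab[embodied_token] = len(vocab)
    return action_ids, vocab[embodied_token]


vocab = {"hello": 0, "world": 1}
action_ids, embodied_id = expand_vocab(
    vocab, chunk_size=2, template="<|action_{}|>", embodied_token="<|embodied_action|>"
)
print(len(action_ids), embodied_id, len(vocab))  # 8 10 11
```

Because of the membership checks, a second call leaves the vocabulary size unchanged, which is why the real method resizes the embedding matrix only when it is smaller than the tokenizer.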

@@ -1,404 +0,0 @@
from __future__ import annotations
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
def build_action_block_causal_attention_mask(
    num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
) -> torch.Tensor:
    tokens_per_frame = add_tokens + grid_height * grid_width
    num_tokens = num_frames * tokens_per_frame
    mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
    mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
    local_window_time = num_frames
    for current_frame in range(num_frames):
        first_context_frame = max(0, current_frame - local_window_time + 1)
        for context_frame in range(first_context_frame, current_frame + 1):
            row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
            col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
            mask[row, col] = mask_block
    return mask
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
_, _, _, dim = x.size()
if dim % 2 != 0:
raise ValueError("Embedding dimension must be even for rotary position encoding.")
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
omega /= dim / 2.0
omega = 1.0 / 10000**omega
freqs = torch.einsum("..., f -> ... f", pos, omega)
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
y = x.unflatten(-1, (-1, 2))
y1, y2 = y.unbind(dim=-1)
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
return x * emb_cos + y * emb_sin
class DropPath(nn.Module):
def __init__(self, drop_prob: float = 0.0) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
return x.div(keep_prob) * random_tensor
class MLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
act_layer: type[nn.Module] = nn.GELU,
drop: float = 0.0,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ACRoPEAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: float | None = None,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = qk_scale or self.head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop_prob = proj_drop
self.proj_drop = nn.Dropout(proj_drop)
self.use_sdpa = use_sdpa
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
self.grid_size = grid_size
self.is_causal = is_causal
@staticmethod
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
return ids // int(height * width)
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
frame_ids = self._get_frame_pos(ids, height, width)
ids = ids - int(height * width) * frame_ids
return ids // width
def separate_positions(
self, ids: torch.Tensor, height: int, width: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
frame_ids = self._get_frame_pos(ids, height, width)
height_ids = self._get_height_pos(ids, height, width)
width_ids = ids - int(height * width) * frame_ids - width * height_ids
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor | None = None,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
batch_size, num_tokens, channels = x.size()
if num_frames is None or grid_height is None or grid_width is None:
raise ValueError("num_frames, grid_height and grid_width are required.")
if mask is not None:
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
else:
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
h_mask *= self.grid_size / grid_height
w_mask *= self.grid_size / grid_width
if action_tokens > 0:
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
action_q, action_k, action_v = [], [], []
for idx in range(action_tokens):
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
qd = rotate_queries_or_keys(
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
kd = rotate_queries_or_keys(
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
qr = q[..., self.d_dim :]
kr = k[..., self.d_dim :]
action_q.append(
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_k.append(
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
x = x[:, :, action_tokens:, :].flatten(1, 2)
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
offset = 0
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
offset += self.d_dim
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
offset += self.h_dim
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
offset += self.w_dim
if offset < self.head_dim:
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
else:
q = torch.cat([qd, qh, qw], dim=-1)
k = torch.cat([kd, kh, kw], dim=-1)
if action_tokens > 0:
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
frame_tokens = frame_tokens.view(
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
)
action_token_values = action_token_values.view(
batch_size, self.num_heads, num_frames, action_tokens, -1
)
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
q = merge(q, action_q)
k = merge(k, action_k)
v = merge(v, action_v)
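if attn_mask is not None or self.use_sdpa:
# SDPA applies dropout whenever dropout_p > 0, even in eval mode, so gate it on self.training.
dropout_p = self.proj_drop_prob if self.training else 0.0
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=dropout_p, is_causal=self.is_causal, attn_mask=attn_mask
)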
else:
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
x = self.proj(x)
return self.proj_drop(x)
class ACBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_scale: float | None = None,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
norm_layer: type[nn.Module] = nn.LayerNorm,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
use_rope: bool = True,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
if not use_rope:
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
self.attn = ACRoPEAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
use_sdpa=use_sdpa,
is_causal=is_causal,
grid_size=grid_size,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = MLP(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=nn.GELU,
drop=drop,
)
def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
y = self.norm1(x)
y = self.attn(
y,
mask=None,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=grid_height,
grid_width=grid_width,
action_tokens=action_tokens,
)
x = x + self.drop_path(y)
y = self.norm2(x)
return x + self.drop_path(self.mlp(y))
class ActionConditionedVideoPredictor(nn.Module):
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
def __init__(
self,
num_frames: int,
img_size: tuple[int, int],
patch_size: int,
tubelet_size: int,
embed_dim: int,
action_embed_dim: int,
predictor_embed_dim: int,
depth: int,
num_heads: int,
mlp_ratio: float,
num_action_tokens_per_step: int,
use_extrinsics: bool = False,
) -> None:
super().__init__()
self.is_frame_causal = True
self.use_extrinsics = use_extrinsics
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
self.img_height, self.img_width = img_size
self.patch_size = patch_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.grid_height = self.img_height // self.patch_size
self.grid_width = self.img_width // self.patch_size
self.predictor_blocks = nn.ModuleList(
[
ACBlock(
dim=predictor_embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
grid_size=self.grid_height,
use_rope=True,
)
for _ in range(depth)
]
)
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
self.num_action_tokens_per_step = num_action_tokens_per_step
@property
def norm(self) -> nn.LayerNorm:
return self.predictor_norm
@property
def proj(self) -> nn.Linear:
return self.predictor_proj
def forward(
self,
frame_tokens: torch.Tensor,
action_tokens: torch.Tensor,
extrinsics: torch.Tensor | None = None,
) -> torch.Tensor:
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
x = self.predictor_embed(frame_tokens)
batch_size, num_context_tokens, hidden_dim = x.size()
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
actions = self.action_encoder(action_tokens)
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
cond_tokens = actions.shape[2]
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
if self.use_extrinsics:
if extrinsics is None:
raise ValueError("extrinsics are required when use_extrinsics=True.")
cond_tokens += 1
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
else:
x = torch.cat([actions, x], dim=2).flatten(1, 2)
attn_mask = build_action_block_causal_attention_mask(
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
)
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
for block in self.predictor_blocks:
x = block(
x,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=self.grid_height,
grid_width=self.grid_width,
action_tokens=cond_tokens,
)
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
x = x[:, :, cond_tokens:, :].flatten(1, 2)
x = self.predictor_norm(x)
return self.predictor_proj(x)
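
Reviewer sketch (not part of the diff): the attention mask built by `build_action_block_causal_attention_mask` is causal over frames, not over individual tokens. Since `local_window_time` equals `num_frames`, every token may attend to all tokens of its own frame and of every earlier frame. The dependency-free function below (hypothetical name `block_causal_mask`, nested lists instead of a tensor) reproduces that pattern so it can be printed.

```python
def block_causal_mask(num_frames, grid_height, grid_width, add_tokens=1):
    """Return a nested-list boolean mask; True means 'row may attend to column'."""
    tokens_per_frame = add_tokens + grid_height * grid_width
    num_tokens = num_frames * tokens_per_frame
    mask = [[False] * num_tokens for _ in range(num_tokens)]
    for row in range(num_tokens):
        for col in range(num_tokens):
            # A token sees every token of its own frame and of earlier frames.
            mask[row][col] = col // tokens_per_frame <= row // tokens_per_frame
    return mask


for row in block_causal_mask(num_frames=2, grid_height=1, grid_width=1):
    print("".join("x" if allowed else "." for allowed in row))
# xx..
# xx..
# xxxx
# xxxx
```

Each frame here has two tokens (one conditioning token plus one patch), so the first frame sees only itself and the second frame sees both. With a boolean `attn_mask`, `F.scaled_dot_product_attention` treats `True` as "takes part in attention", which is the convention used above.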

@@ -24,7 +24,6 @@ import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import datasets
from huggingface_hub import HfApi
from PIL import Image
from safetensors.torch import load_file
@@ -361,41 +360,6 @@ def test_add_frame_image_pil(image_dataset):
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@pytest.mark.parametrize(
"dtype,np_dtype,values,assert_fn",
[
("float32", np.float32, [1.0, 2.0], np.testing.assert_allclose),
("int64", np.int64, [1, 2], np.testing.assert_array_equal),
("bool", np.bool_, [True, False], np.testing.assert_array_equal),
],
ids=["float32", "int64", "bool"],
)
def test_save_episode_shape_1_scalar_is_scalarized_before_hf_encoding(
tmp_path, empty_lerobot_dataset_factory, monkeypatch, dtype, np_dtype, values, assert_fn
):
features = {"state": {"dtype": dtype, "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": np.array([values[0]], dtype=np_dtype), "task": "Dummy task"})
dataset.add_frame({"state": np.array([values[1]], dtype=np_dtype), "task": "Dummy task"})
captured = {}
original_from_dict = datasets.Dataset.from_dict
def _from_dict_spy(cls, mapping, *args, **kwargs):
captured["state"] = mapping["state"]
return original_from_dict(mapping, *args, **kwargs)
monkeypatch.setattr(datasets.Dataset, "from_dict", classmethod(_from_dict_spy))
dataset.save_episode()
dataset.finalize()
assert "state" in captured
assert isinstance(captured["state"], np.ndarray)
assert captured["state"].shape == (2,)
assert_fn(captured["state"], np.array(values, dtype=np_dtype))
def test_set_image_transforms_applies_transparently(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})

@@ -1,140 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for ``lerobot.datasets.video_utils.VideoDecoderCache``.
These cover the LRU bounding + file-handle release behaviour added to prevent
unbounded growth when iterating over datasets with many distinct video files
(observed: ~35 GB anon-rss per DataLoader worker on an 8 k-file dataset).
"""
import shutil
from pathlib import Path
import pytest
pytest.importorskip("torchcodec", reason="torchcodec is required (install lerobot[dataset])")
from lerobot.datasets.video_utils import VideoDecoderCache # noqa: E402
TEST_ARTIFACTS_DIR = Path(__file__).resolve().parent.parent / "artifacts" / "encoded_videos"
SRC_CLIP = TEST_ARTIFACTS_DIR / "clip_4frames.mp4"
def _make_distinct_clips(tmp_path: Path, n: int) -> list[Path]:
"""Copy the small reference mp4 to ``n`` distinct paths.
The cache keys on absolute path, so distinct paths force distinct cache entries
even though the file contents are identical.
"""
assert SRC_CLIP.exists(), f"missing test artifact {SRC_CLIP}"
paths = []
for i in range(n):
dst = tmp_path / f"clip_{i:04d}.mp4"
shutil.copyfile(SRC_CLIP, dst)
paths.append(dst)
return paths
class TestVideoDecoderCacheBounded:
def test_default_cache_is_bounded(self):
"""The default cache must have a finite ``max_size`` to bound RSS growth."""
cache = VideoDecoderCache()
assert cache.max_size is not None, "default cache must be bounded"
assert cache.max_size > 0
def test_size_capped_at_max_size(self, tmp_path):
"""``get_decoder`` for >``max_size`` distinct paths must NOT grow without bound."""
paths = _make_distinct_clips(tmp_path, n=5)
cache = VideoDecoderCache(max_size=2)
for p in paths:
cache.get_decoder(p)
assert cache.size() == 2
def test_evicts_least_recently_used(self, tmp_path):
"""Re-accessing an entry must promote it; the LRU entry is the one evicted."""
paths = _make_distinct_clips(tmp_path, n=3)
cache = VideoDecoderCache(max_size=2)
cache.get_decoder(paths[0])
cache.get_decoder(paths[1])
cache.get_decoder(paths[0]) # promote paths[0] to MRU; paths[1] is now LRU
cache.get_decoder(paths[2]) # should evict paths[1]
assert str(paths[0]) in cache # MRU stays
assert str(paths[1]) not in cache # LRU evicted
assert str(paths[2]) in cache # newest stays
def test_eviction_closes_file_handle(self, tmp_path):
"""Evicting an entry must close its fsspec file handle (otherwise we leak FDs)."""
paths = _make_distinct_clips(tmp_path, n=2)
cache = VideoDecoderCache(max_size=1)
cache.get_decoder(paths[0])
# Reach into the cache to capture the handle before it is evicted. This is
# the only assertion in the suite that touches a private attribute, and it
# is the most direct way to prove the file descriptor is actually released.
evicted_handle = cache._cache[str(paths[0])][1]
assert evicted_handle.closed is False
cache.get_decoder(paths[1]) # forces eviction of paths[0]
assert evicted_handle.closed is True
def test_clear_closes_all_file_handles(self, tmp_path):
"""``clear()`` must close every cached file handle."""
paths = _make_distinct_clips(tmp_path, n=3)
cache = VideoDecoderCache(max_size=10)
for p in paths:
cache.get_decoder(p)
handles = [entry[1] for entry in cache._cache.values()]
assert all(not h.closed for h in handles)
cache.clear()
assert cache.size() == 0
assert all(h.closed for h in handles)
def test_hit_does_not_reopen_or_evict(self, tmp_path):
"""A cache hit must return the same decoder instance without touching the cap."""
paths = _make_distinct_clips(tmp_path, n=1)
cache = VideoDecoderCache(max_size=2)
first = cache.get_decoder(paths[0])
second = cache.get_decoder(paths[0])
assert first is second
assert cache.size() == 1
def test_unbounded_when_max_size_none(self, tmp_path):
"""``max_size=None`` preserves the legacy unbounded behaviour."""
paths = _make_distinct_clips(tmp_path, n=4)
cache = VideoDecoderCache(max_size=None)
for p in paths:
cache.get_decoder(p)
assert cache.size() == 4
def test_env_var_overrides_default(self, tmp_path, monkeypatch):
"""``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var sets the default ``max_size``."""
monkeypatch.setenv("LEROBOT_VIDEO_DECODER_CACHE_SIZE", "3")
cache = VideoDecoderCache()
assert cache.max_size == 3
paths = _make_distinct_clips(tmp_path, n=5)
for p in paths:
cache.get_decoder(p)
assert cache.size() == 3
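
Reviewer sketch (not part of the diff): the behaviour these tests pin down (bounded size, LRU eviction, handles closed on eviction and on `clear()`) can be sketched with `collections.OrderedDict`. `BoundedHandleCache` is a hypothetical stand-in and makes no claim about how `VideoDecoderCache` is implemented; it stores bare file-like handles rather than decoder and handle pairs.

```python
import io
from collections import OrderedDict


class BoundedHandleCache:
    """LRU cache of open file-like handles (illustrative, not VideoDecoderCache)."""

    def __init__(self, max_size=None):
        self.max_size = max_size
        self._cache = OrderedDict()

    def get(self, key, opener):
        if key in self._cache:
            self._cache.move_to_end(key)  # a hit promotes the entry to most recently used
            return self._cache[key]
        handle = opener()
        self._cache[key] = handle
        if self.max_size is not None and len(self._cache) > self.max_size:
            _, evicted = self._cache.popitem(last=False)  # drop the least recently used entry
            evicted.close()
        return handle

    def clear(self):
        for handle in self._cache.values():
            handle.close()
        self._cache.clear()


cache = BoundedHandleCache(max_size=2)
a = cache.get("a", io.BytesIO)
cache.get("b", io.BytesIO)
cache.get("a", io.BytesIO)  # promote "a"; "b" becomes least recently used
cache.get("c", io.BytesIO)  # evicts "b" and closes its handle
print(list(cache._cache), a.closed)  # ['a', 'c'] False
```

Closing the handle at eviction time is the step that releases the file descriptor, and it is the behaviour `test_eviction_closes_file_handle` checks.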

File diff suppressed because it is too large

@@ -1 +0,0 @@
"""Lightweight vendored OpenPI PyTorch modules for PI0/PI05 parity tests."""

@@ -1,22 +0,0 @@
from dataclasses import dataclass


@dataclass
class Config:
    width: int
    depth: int
    mlp_dim: int
    num_heads: int
    num_kv_heads: int
    head_dim: int


def get_config(variant: str) -> Config:
    """Return the Gemma shape config needed by the OpenPI PyTorch model."""
    if variant == "dummy":
        return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16)
    if variant == "gemma_300m":
        return Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256)
    if variant == "gemma_2b":
        return Config(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256)
    raise ValueError(f"Unknown variant: {variant}")
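
Reviewer sketch (not part of the diff): every variant sets `num_kv_heads=1`, so all eight query heads share a single key/value head. Assuming the usual layout in which the query projection has `num_heads * head_dim` outputs and the key and value projections have `num_kv_heads * head_dim` outputs, the widths follow directly from the config. `projection_widths` is a hypothetical helper, and `Config` is repeated only to keep the block self-contained.

```python
from dataclasses import dataclass


@dataclass
class Config:
    width: int
    depth: int
    mlp_dim: int
    num_heads: int
    num_kv_heads: int
    head_dim: int


def projection_widths(cfg: Config) -> tuple[int, int]:
    """Return (query width, key/value width) of one attention layer."""
    return cfg.num_heads * cfg.head_dim, cfg.num_kv_heads * cfg.head_dim


gemma_300m = Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256)
print(projection_widths(gemma_300m))  # (2048, 256)
```

Under that assumption the `gemma_300m` query projection (2048) is wider than the model width (1024), the same query width as `gemma_2b`.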

@@ -1,300 +0,0 @@
from typing import Literal
import torch
from torch import nn
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
layernorm_forward,
)
class PaliGemmaWithExpertModel(nn.Module):
def __init__(
self,
vlm_config,
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
):
if use_adarms is None:
use_adarms = [False, False]
super().__init__()
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
vlm_config_hf.image_token_index = 257152
vlm_config_hf.text_config.hidden_size = vlm_config.width
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
hidden_size=action_expert_config.width,
intermediate_size=action_expert_config.mlp_dim,
num_attention_heads=action_expert_config.num_heads,
num_hidden_layers=action_expert_config.depth,
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
if precision == "bfloat16":
self.to(dtype=torch.bfloat16)
elif precision == "float32":
self.to(dtype=torch.float32)
return
else:
raise ValueError(f"Invalid precision: {precision}")
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_keep_float32):
param.data = param.data.to(dtype=torch.float32)
def embed_image(self, image: torch.Tensor):
# Transformers 5.4 no longer divides PaliGemma image features by sqrt(hidden_size),
# so the upstream helper now matches OpenPI's patched PaliGemma image-scale semantics.
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-c916907e7e52ac85ee1a1527560eae4656cd6c76141ceb1fe3da61bd5f697d2a
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
def forward(
self,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: list[torch.FloatTensor] | None = None,
use_cache: bool | None = None,
adarms_cond: list[torch.Tensor] | None = None,
):
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[0],
)
prefix_past_key_values = prefix_output.past_key_values
prefix_output = prefix_output.last_hidden_state
suffix_output = None
elif inputs_embeds[0] is None:
suffix_output = self.gemma_expert.model.forward(
inputs_embeds=inputs_embeds[1],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[1],
)
suffix_output = suffix_output.last_hidden_state
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
hasattr(self.gemma_expert.model, "gradient_checkpointing")
and self.gemma_expert.model.gradient_checkpointing
and self.training
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Force enable gradient checkpointing if we're in training mode and the model supports it
if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
if not self.gemma_expert.model.gradient_checkpointing:
print("Forcing gradient checkpointing to be enabled for Gemma expert model")
self.gemma_expert.model.gradient_checkpointing = True
use_gradient_checkpointing = True
# Debug gradient checkpointing status
if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
print(f"Model training mode: {self.training}")
print(
f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
)
if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
print(
f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
)
self._debug_gc_printed = True
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
models = [self.paligemma.model.language_model, self.gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(
layer.input_layernorm, hidden_states, adarms_cond[i]
)
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# Concatenate and process attention
query_states = torch.cat(query_states, dim=2)
key_states = torch.cat(key_states, dim=2)
value_states = torch.cat(value_states, dim=2)
dummy_tensor = torch.zeros(
query_states.shape[0],
query_states.shape[2],
query_states.shape[-1],
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = self.paligemma.model.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
self.paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
attention_mask,
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = self.paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
# NOTE: the head count (8) is hardcoded here rather than read from the config
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
after_first_residual = out_emb.clone()
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = _gated_residual(after_first_residual, out_emb, gate)
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
)
# Final norm, wrapped in a function so it can be gradient-checkpointed
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
# Apply gradient checkpointing to final norm if enabled
if use_gradient_checkpointing:
outputs_embeds = torch.utils.checkpoint.checkpoint(
compute_final_norms,
inputs_embeds,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
)
else:
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
prefix_output = outputs_embeds[0]
suffix_output = outputs_embeds[1]
prefix_past_key_values = None
return [prefix_output, suffix_output], prefix_past_key_values
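
Review note on `to_bfloat16_for_selected_params` above: the float32 exceptions are picked by plain substring matching on parameter names. The sketch below isolates that rule in pure Python. The selector strings are copied from the method; the parameter names are made up for illustration and are an assumption, not names read from a real checkpoint.

```python
# Selector substrings copied from params_to_keep_float32 in the method above.
PARAMS_TO_KEEP_FLOAT32 = (
    "vision_tower",
    "multi_modal_projector",
    "input_layernorm",
    "post_attention_layernorm",
    "model.norm",
)


def keeps_float32(param_name: str) -> bool:
    """Return True if any selector occurs as a substring of the parameter name."""
    return any(selector in param_name for selector in PARAMS_TO_KEEP_FLOAT32)


# Hypothetical parameter names, for illustration only.
example_names = [
    "paligemma.model.vision_tower.encoder.layers.0.mlp.fc1.weight",
    "paligemma.model.language_model.layers.0.input_layernorm.weight",
    "paligemma.model.language_model.layers.0.self_attn.q_proj.weight",
    "paligemma.model.language_model.norm.weight",
    "gemma_expert.model.norm.weight",
]

for name in example_names:
    target = "float32" if keeps_float32(name) else "bfloat16"
    print(f"{target:8s} {name}")
```

Because the match is on substrings, `"model.norm"` also catches names ending in `language_model.norm`, which is worth keeping in mind if modules are ever renamed.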


@@ -1,79 +0,0 @@
import torch
import torch.nn.functional as F # noqa: N812
def resize_with_pad_torch(
images: torch.Tensor,
height: int,
width: int,
mode: str = "bilinear",
) -> torch.Tensor:
"""PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
by padding with black. If the image is float32, it must be in the range [-1, 1].
Args:
images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
height: Target height
width: Target width
mode: Interpolation mode ('bilinear', 'nearest', etc.)
Returns:
Resized and padded tensor with same shape format as input
"""
# Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
if images.shape[-1] <= 4: # Assume channels-last format
channels_last = True
# Convert to channels-first for torch operations
if images.dim() == 3:
images = images.unsqueeze(0) # Add batch dimension
images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w]
else:
channels_last = False
if images.dim() == 3:
images = images.unsqueeze(0) # Add batch dimension
batch_size, channels, cur_height, cur_width = images.shape
# Calculate resize ratio
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
# Resize
resized_images = F.interpolate(
images,
size=(resized_height, resized_width),
mode=mode,
align_corners=False if mode == "bilinear" else None,
)
# Handle dtype-specific clipping
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(-1.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
# Calculate padding
pad_h0, remainder_h = divmod(height - resized_height, 2)
pad_h1 = pad_h0 + remainder_h
pad_w0, remainder_w = divmod(width - resized_width, 2)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else -1.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
mode="constant",
value=constant_value,
)
# Convert back to original format if needed
if channels_last:
padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
if batch_size == 1:
padded_images = padded_images.squeeze(0) # Drops the batch dimension; note this also applies to a genuine batch of size 1
return padded_images
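
Review note: the resize-then-pad arithmetic in `resize_with_pad_torch` can be checked without tensors. The sketch below repeats the same integer math in pure Python; the helper name `letterbox_geometry` is hypothetical and not part of the deleted file.

```python
def letterbox_geometry(cur_height: int, cur_width: int, height: int, width: int):
    """Mirror the size and padding math of resize_with_pad_torch.

    Returns ((resized_height, resized_width), (left, right, top, bottom)).
    """
    # The larger of the two ratios decides the scale, so the image fits in both axes.
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    # Split the leftover pixels evenly; an odd remainder goes to the bottom/right.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w
    return (resized_height, resized_width), (pad_w0, pad_w1, pad_h0, pad_h1)


# A 448x896 (h x w) frame scaled into a 224x224 canvas.
size, padding = letterbox_geometry(448, 896, 224, 224)
print(size, padding)
```

The padding tuple is ordered left, right, top, bottom, matching the argument order that `F.pad` expects for the last two dimensions.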


@@ -1,471 +0,0 @@
import copy
import logging
import math
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
import tests.policies.pi0_pi05.openpi_pytorch.gemma as _gemma
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as _preprocessing
from tests.policies.pi0_pi05.openpi_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "cpu":
# CPU doesn't support bfloat16, use float32 instead
if target_dtype == torch.bfloat16:
return torch.float32
if target_dtype == torch.float64:
return torch.float64
return target_dtype
def create_sinusoidal_pos_embedding(
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
# Accept both a string and a torch.device for `device`
dtype = get_safe_dtype(torch.float64, torch.device(device).type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
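
Review note: a scalar, stdlib-only sketch of the embedding computed by `create_sinusoidal_pos_embedding`, for a single time value instead of a batch. The function name `sinusoidal_embedding` is hypothetical; the period schedule follows the code above.

```python
import math


def sinusoidal_embedding(t: float, dimension: int, min_period: float, max_period: float):
    """Sine-cosine embedding of one scalar time value, as a list of floats."""
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")
    half = dimension // 2
    # Evenly spaced fractions in [0, 1], like torch.linspace(0.0, 1.0, half).
    fractions = [i / (half - 1) if half > 1 else 0.0 for i in range(half)]
    # Periods grow geometrically from min_period to max_period.
    periods = [min_period * (max_period / min_period) ** f for f in fractions]
    angles = [2 * math.pi * t / p for p in periods]
    # All sines first, then all cosines, matching torch.cat([...], dim=1).
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


print(sinusoidal_embedding(0.0, 4, min_period=4e-3, max_period=4.0))
```

At `t = 0` every sine entry is 0 and every cosine entry is 1, which gives a quick sanity check on the layout of the output vector.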
def sample_beta(alpha, beta, bsize, device):
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,))
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
Tokens can attend to valid input tokens whose cumulative `att_masks` value is
smaller than or equal to their own. This way `att_masks` int[B, N] can be used to
set up several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend to each
other and the last 3 tokens have causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend to all previous blocks and all tokens in the same block.
Args:
pad_masks: bool[B, N] true if it is part of the input, false if padding.
att_masks: int32[B, N] mask that is 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
return att_2d_masks & pad_2d_masks
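
Review note: the cumulative-sum rule in `make_att_2d_masks` is easier to see on a single sequence. The sketch below is a pure-Python, unbatched restatement (the name `make_att_2d_mask` and the list-of-lists output are assumptions made for illustration).

```python
from itertools import accumulate


def make_att_2d_mask(pad_mask, att_mask):
    """Unbatched version: entry [q][k] is True if query q may attend to key k."""
    cumsum = list(accumulate(att_mask))
    n = len(att_mask)
    return [
        [bool(cumsum[k] <= cumsum[q] and pad_mask[q] and pad_mask[k]) for k in range(n)]
        for q in range(n)
    ]


# Prefix-LM pattern from the docstring: 3 prefix tokens, then 3 causal tokens.
mask = make_att_2d_mask([True] * 6, [0, 0, 0, 1, 1, 1])
for row in mask:
    print(" ".join("1" if allowed else "0" for allowed in row))
```

Rows are queries and columns are keys: the first three rows see only the prefix block, and each later row sees the prefix plus every causal token up to and including itself.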
class PI0Pytorch(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.pi05 = config.pi05
paligemma_config = _gemma.get_config(config.paligemma_variant)
action_expert_config = _gemma.get_config(config.action_expert_variant)
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_config,
action_expert_config,
use_adarms=[False, True] if self.pi05 else [False, False],
precision=config.dtype,
)
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)
if self.pi05:
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
else:
self.state_proj = nn.Linear(config.action_dim, action_expert_config.width)
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
torch.set_float32_matmul_precision("high")
if config.pytorch_compile_mode is not None:
self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode)
# Initialize gradient checkpointing flag
self.gradient_checkpointing_enabled = False
# The upstream OpenPI module verifies a site-package Transformers patch here.
# This vendored test copy instead routes through LeRobot's local PiGemma compatibility layer.
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
def is_gradient_checkpointing_enabled(self):
"""Check if gradient checkpointing is enabled."""
return self.gradient_checkpointing_enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
return torch.utils.checkpoint.checkpoint(
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
)
return func(*args, **kwargs)
def _prepare_attention_masks_4d(self, att_2d_masks):
"""Helper method to prepare 4D attention masks for transformer."""
att_2d_masks_4d = att_2d_masks[:, None, :, :]
return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
def _preprocess_observation(self, observation, *, train=True):
"""Helper method to preprocess observation."""
observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
return (
list(observation.images.values()),
list(observation.image_masks.values()),
observation.tokenized_prompt,
observation.tokenized_prompt_mask,
observation.state,
)
def sample_noise(self, shape, device):
return torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer to prepare
for PaliGemma transformer processing.
"""
embs = []
pad_masks = []
att_masks = []
# Process images
for img, img_mask in zip(images, img_masks, strict=True):
def image_embed_func(img):
return self.paligemma_with_expert.embed_image(img)
img_emb = self._apply_checkpoint(image_embed_func, img)
bsize, num_img_embs = img_emb.shape[:2]
embs.append(img_emb)
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
# Create attention masks so that image tokens attend to each other
att_masks += [0] * num_img_embs
# Process language tokens
def lang_embed_func(lang_tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
# Transformers > 5.4 scales Gemma token embeddings inside embed_tokens, matching
# OpenPI's former explicit sqrt(hidden_size) multiply without applying it twice.
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834
return lang_emb
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
embs.append(lang_emb)
pad_masks.append(lang_masks)
# full attention between image and language inputs
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
# Get batch size from the first dimension of the concatenated tensors
bsize = pad_masks.shape[0]
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def embed_suffix(self, state, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = []
pad_masks = []
att_masks = []
if not self.pi05:
if self.state_proj.weight.dtype == torch.float32:
state = state.to(torch.float32)
# Embed state
def state_proj_func(state):
return self.state_proj(state)
state_emb = self._apply_checkpoint(state_proj_func, state)
embs.append(state_emb[:, None, :])
bsize = state_emb.shape[0]
device = state_emb.device
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
pad_masks.append(state_mask)
# Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1]
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep,
self.action_in_proj.out_features,
min_period=4e-3,
max_period=4.0,
device=timestep.device,
)
time_emb = time_emb.type(dtype=timestep.dtype)
# Fuse timestep + action information using an MLP
def action_proj_func(noisy_actions):
return self.action_in_proj(noisy_actions)
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
if not self.pi05:
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
# Apply MLP layers
def mlp_func(action_time_emb):
x = self.action_time_mlp_in(action_time_emb)
x = F.silu(x) # swish == silu
return self.action_time_mlp_out(x)
action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
adarms_cond = None
else:
# time MLP (for adaRMS)
def time_mlp_func(time_emb):
x = self.time_mlp_in(time_emb)
x = F.silu(x) # swish == silu
x = self.time_mlp_out(x)
return F.silu(x)
time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
action_time_emb = action_emb
adarms_cond = time_emb
# Add to input tokens
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
att_masks += [1] + ([0] * (self.config.action_horizon - 1))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks, adarms_cond
def forward(self, observation, actions, noise=None, time=None) -> Tensor:
"""Run a full training forward pass and return the unreduced loss (batch_size x action_horizon x action_dim)."""
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
observation, train=True
)
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
# Prepare attention masks
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
# Apply gradient checkpointing if enabled
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
return suffix_out
suffix_out = self._apply_checkpoint(
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
)
suffix_out = suffix_out[:, -self.config.action_horizon :]
suffix_out = suffix_out.to(dtype=torch.float32)
# Apply gradient checkpointing to final action projection if enabled
def action_out_proj_func(suffix_out):
return self.action_out_proj(suffix_out)
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad()
def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
"""Run a full inference pass and compute the actions (batch_size x action_horizon x action_dim)."""
bsize = observation.state.shape[0]
if noise is None:
actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
noise = self.sample_noise(actions_shape, device)
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
observation, train=False
)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# Compute image and language key value cache
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks_4d,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=True,
)
dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Euler step - use new tensor assignment instead of in-place operation
x_t = x_t + dt * v_t
time += dt
return x_t
def denoise_step(
self,
state,
prefix_pad_masks,
past_key_values,
x_t,
timestep,
):
"""Predict the velocity `v_t` for the noisy actions `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
# Prepare attention masks
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.action_horizon :]
suffix_out = suffix_out.to(dtype=torch.float32)
return self.action_out_proj(suffix_out)
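
Review note: `forward` trains the network to predict `u_t = noise - actions` along the path `x_t = t * noise + (1 - t) * actions`, and `sample_actions` integrates that velocity from `t = 1` down to `t = 0` with Euler steps. The sketch below replays the sampling loop on scalars, with an ideal constant velocity standing in for the network. That stand-in and the name `euler_integrate` are assumptions made for illustration, not the model itself.

```python
def euler_integrate(x_noise: float, velocity_fn, num_steps: int = 10):
    """Replay the loop in sample_actions on plain floats; returns (x_0, steps_taken)."""
    dt = -1.0 / num_steps
    x_t = x_noise
    time = 1.0
    steps = 0
    # Stop once time falls below half a step, which tolerates rounding drift.
    while time >= -dt / 2:
        x_t = x_t + dt * velocity_fn(x_t, time)
        time += dt
        steps += 1
    return x_t, steps


noise, action = 2.0, -1.0


def ideal_velocity(x_t: float, time: float) -> float:
    """Perfect prediction of the training target u_t = noise - actions."""
    return noise - action


x_0, steps = euler_integrate(noise, ideal_velocity, num_steps=10)
print(f"steps={steps}, recovered action ~ {x_0:.6f}")
```

With a perfect velocity the loop lands on the clean action after exactly `num_steps` updates; with a learned velocity the result is only as good as the prediction at each step.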


@@ -1,179 +0,0 @@
import logging
from collections.abc import Sequence
import torch
from tests.policies.pi0_pi05.openpi_pytorch import image_tools
logger = logging.getLogger("openpi")
# Constants moved from model.py
IMAGE_KEYS = (
"base_0_rgb",
"left_wrist_0_rgb",
"right_wrist_0_rgb",
)
IMAGE_RESOLUTION = (224, 224)
def preprocess_observation_pytorch(
observation,
*,
train: bool = False,
image_keys: Sequence[str] = IMAGE_KEYS,
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
):
"""Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
This function avoids complex type annotations that can cause torch.compile issues.
"""
if not set(image_keys).issubset(observation.images):
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
batch_shape = observation.state.shape[:-1]
out_images = {}
for key in image_keys:
image = observation.images[key]
# TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1
if is_channels_first:
# Convert [B, C, H, W] to [B, H, W, C] for processing
image = image.permute(0, 2, 3, 1)
if image.shape[1:3] != image_resolution:
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
image = image_tools.resize_with_pad_torch(image, *image_resolution)
if train:
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
image = image / 2.0 + 0.5
# Apply PyTorch-based augmentations
if "wrist" not in key:
# Geometric augmentations for non-wrist cameras
height, width = image.shape[1:3]
# Random crop and resize
crop_height = int(height * 0.95)
crop_width = int(width * 0.95)
# Random crop
max_h = height - crop_height
max_w = width - crop_width
if max_h > 0 and max_w > 0:
# Use tensor operations instead of .item() for torch.compile compatibility
start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
# Resize back to original size
image = torch.nn.functional.interpolate(
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
size=(height, width),
mode="bilinear",
align_corners=False,
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
# Random rotation (small angles)
# Use tensor operations instead of .item() for torch.compile compatibility
angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
# Convert to radians
angle_rad = angle * torch.pi / 180.0
# Create rotation matrix
cos_a = torch.cos(angle_rad)
sin_a = torch.sin(angle_rad)
# Apply rotation using grid_sample
grid_x = torch.linspace(-1, 1, width, device=image.device)
grid_y = torch.linspace(-1, 1, height, device=image.device)
# Create meshgrid
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
# Expand to batch dimension
grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)
# Apply rotation transformation
grid_x_rot = grid_x * cos_a - grid_y * sin_a
grid_y_rot = grid_x * sin_a + grid_y * cos_a
# Stack and reshape for grid_sample
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
image = torch.nn.functional.grid_sample(
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
grid,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
# Color augmentations for all cameras
# Random brightness
# Use tensor operations instead of .item() for torch.compile compatibility
brightness_factor = (
0.7 + torch.rand(1, device=image.device) * 0.6
) # Random factor between 0.7 and 1.3
image = image * brightness_factor
# Random contrast
# Use tensor operations instead of .item() for torch.compile compatibility
contrast_factor = (
0.6 + torch.rand(1, device=image.device) * 0.8
) # Random factor between 0.6 and 1.4
mean = image.mean(dim=[1, 2, 3], keepdim=True)
image = (image - mean) * contrast_factor + mean
# Random saturation: instead of a full HSV round trip, scale each pixel's
# deviation from its own gray value (the mean over the color channels)
# Use tensor operations instead of .item() for torch.compile compatibility
saturation_factor = (
0.5 + torch.rand(1, device=image.device) * 1.0
) # Random factor between 0.5 and 1.5
gray = image.mean(dim=-1, keepdim=True)
image = gray + (image - gray) * saturation_factor
# Clamp values to [0, 1]
image = torch.clamp(image, 0, 1)
# Back to [-1, 1]
image = image * 2.0 - 1.0
# Convert back to [B, C, H, W] format if it was originally channels-first
if is_channels_first:
image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
out_images[key] = image
# obtain mask
out_masks = {}
for key in out_images:
if key not in observation.image_masks:
# do not mask by default
out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)
else:
out_masks[key] = observation.image_masks[key]
# Create a simple object with the required attributes instead of using the complex Observation class
class SimpleProcessedObservation:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
return SimpleProcessedObservation(
images=out_images,
image_masks=out_masks,
state=observation.state,
tokenized_prompt=observation.tokenized_prompt,
tokenized_prompt_mask=observation.tokenized_prompt_mask,
token_ar_mask=observation.token_ar_mask,
token_loss_mask=observation.token_loss_mask,
)
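
Review note: the color augmentations above run in [0, 1] and are mapped back to [-1, 1] at the end. The sketch below follows one RGB pixel through brightness, saturation, and clamping in pure Python. Contrast is left out because it depends on the mean of the whole image; the helper name `augment_pixel` and the fixed factors are assumptions (the real code draws the factors at random).

```python
def augment_pixel(rgb, brightness: float, saturation: float):
    """Apply the brightness and saturation steps to one pixel given in [-1, 1]."""
    # [-1, 1] -> [0, 1]
    unit = [c / 2.0 + 0.5 for c in rgb]
    # Brightness: plain multiplicative scaling.
    unit = [c * brightness for c in unit]
    # Saturation: scale the deviation from the per-pixel gray value.
    gray = sum(unit) / len(unit)
    unit = [gray + (c - gray) * saturation for c in unit]
    # Clamp to [0, 1], then map back to [-1, 1].
    unit = [min(max(c, 0.0), 1.0) for c in unit]
    return [c * 2.0 - 1.0 for c in unit]


print(augment_pixel([-1.0, 0.0, 1.0], brightness=1.0, saturation=0.0))
```

A saturation factor of 0 collapses the pixel to its gray value, and a factor of 1 leaves it unchanged, which brackets the 0.5 to 1.5 range sampled in the code above.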


@@ -1,101 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.policies.pi05 import PI05Config # noqa: E402
from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch # noqa: E402
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
assert_cache_stability,
assert_compiled_output_matches_eager,
assert_explain_has_no_graph_breaks,
benchmark_runtime,
make_compile_config,
reset_compile_state,
)
from tests.utils import require_cuda # noqa: E402
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
)
def _make_model(*, compile_model):
return PI05Pytorch(make_compile_config(PI05Config, compile_model=compile_model)).cuda().eval()
def _make_dummy_inputs(config):
device = torch.device("cuda")
common = {
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
"tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
"masks": torch.ones(1, 5, dtype=torch.bool, device=device),
}
forward_kwargs = {
**common,
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"time": torch.rand(1, device=device),
}
sample_kwargs = {
**common,
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"num_steps": config.num_inference_steps,
}
return forward_kwargs, sample_kwargs
@require_cuda
def test_pi05_torch_compile_forward_and_sample_actions():
if not hasattr(torch, "compile"):
pytest.skip("torch.compile is not available")
if not torch._dynamo.is_dynamo_supported():
pytest.skip("torch._dynamo is not supported on this platform")
torch.manual_seed(0)
eager_model = _make_model(compile_model=False)
torch.manual_seed(0)
compiled_model = _make_model(compile_model=True)
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
try:
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi05.forward")
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi05.sample_actions")
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi05.forward")
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi05.sample_actions")
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi05.forward")
benchmark_runtime(
eager_model.sample_actions,
compiled_model.sample_actions,
sample_kwargs,
"pi05.sample_actions",
)
finally:
reset_compile_state()
del eager_model
del compiled_model
torch.cuda.empty_cache()
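
Reviewer note: the module-level `pytestmark` in the file above skips every test when `CI` or `GITHUB_ACTIONS` equals the exact string `"true"`. The sketch below isolates that predicate so it can be checked without pytest. `running_in_ci` is a hypothetical helper name, not something defined in the diff.

```python
import os


def running_in_ci(env=None):
    """Mirror the skip condition: CI == "true" or GITHUB_ACTIONS == "true"."""
    env = os.environ if env is None else env
    return env.get("CI") == "true" or env.get("GITHUB_ACTIONS") == "true"


print(running_in_ci({"CI": "true"}))
print(running_in_ci({"GITHUB_ACTIONS": "true"}))
print(running_in_ci({"CI": "1"}))
```

Because the comparison is an exact string match, a runner that exports `CI=1` or `CI=True` would not trigger the skip under this condition.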

@@ -14,56 +14,52 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare LeRobot PI0.5 against the vendored OpenPI PyTorch reference."""
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation"""
import gc
import os
from copy import deepcopy
from typing import Any
import numpy as np
import pytest
import torch
# Skip if openpi or transformers is not available
pytest.importorskip("openpi")
pytest.importorskip("transformers")
from lerobot.configs import PreTrainedConfig # noqa: E402
from lerobot.policies.pi05 import PI05Policy # noqa: E402
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
assert_processor_inputs_match_lerobot,
clone_batch,
deterministic_openpi_forward_preprocess,
fix_reference_state_dict,
fixed_flow_sampling,
load_openpi_reference_state_dict,
make_openpi_observation_from_raw,
openpi_model_actions_from_raw,
)
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from transformers import AutoTokenizer # noqa: E402
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
from lerobot.types import PolicyAction # noqa: E402
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 200
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
COMPILE_MODE = "default"
FORWARD_RTOL = 1e-4
FORWARD_ATOL = 1e-4
SAMPLE_RTOL = 1e-2
SAMPLE_ATOL = 5e-3
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
DUMMY_DATASET_STATS = {
OBS_STATE: {
"observation.state": {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
"q01": torch.zeros(DUMMY_STATE_DIM),
"q99": torch.ones(DUMMY_STATE_DIM),
},
ACTION: {
"action": {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
"q01": torch.zeros(DUMMY_ACTION_DIM),
@@ -92,15 +88,6 @@ DUMMY_DATASET_STATS = {
}
@pytest.fixture(autouse=True)
def cleanup_cuda_after_test():
yield
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
class PI05BaseOriginalConfig:
action_dim: int = DUMMY_ACTION_DIM
action_horizon: int = DUMMY_ACTION_HORIZON
@@ -109,163 +96,341 @@ class PI05BaseOriginalConfig:
precision: str = "float32"
pi05: bool = True
dtype: str = "float32"
pytorch_compile_mode: str | None = None
def instantiate_lerobot_pi05(*, compile_model: bool = False, gradient_checkpointing: bool = False):
config = PreTrainedConfig.from_pretrained("lerobot/pi05_base")
config.device = str(DEVICE)
config.dtype = "float32"
config.compile_model = compile_model
config.compile_mode = COMPILE_MODE
config.gradient_checkpointing = gradient_checkpointing
def instantiate_lerobot_pi05(
from_pretrained: bool = False,
) -> tuple[
PI05Policy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
if from_pretrained:
# Load the policy first
policy = PI05Policy.from_pretrained(pretrained_name_or_path="lerobot/pi05_base", strict=True)
else:
config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
policy = PI05Policy(config)
policy = PI05Policy.from_pretrained("lerobot/pi05_base", config=config, strict=True)
policy.to(DEVICE)
policy.config.device = str(DEVICE)
preprocessor, _ = make_pi05_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
return policy, preprocessor
policy.config.device = DEVICE
preprocessor, postprocessor = make_pi05_pre_post_processors(
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
)
return (policy, preprocessor, postprocessor)
def instantiate_original_pi05():
policy = PI0Pytorch(PI05BaseOriginalConfig()).to(DEVICE)
def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None):
config = PI05BaseOriginalConfig()
policy = PI0Pytorch(config)
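# NOTE: As with PI0, the LeRobot loader for `lerobot/pi05_base` performs a key
# compatibility conversion before the strict load, so no missing_keys or
# unexpected_keys are expected. The vendored reference, by contrast, is a bare
# `nn.Module`, so the test has to bridge the minimal naming differences between
# the checkpoint and the module itself.
# NOTE: `lm_head.weight` is the saved name of the PaliGemma tied embedding. LeRobot's
# from_pretrained maps it to the internal `embed_tokens.weight`, but the reference
# model has no such loader, so the same tensor is reused manually here to avoid
# mistaking a weight-alias difference for a model difference.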
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi05_base"))
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
assert missing_keys == []
assert unexpected_keys == []
if from_pretrained:
try:
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi05_base)...")
# Download the model from HuggingFace Hub
import safetensors.torch
from huggingface_hub import snapshot_download
# Download the entire repository
if model_path and os.path.exists(model_path):
cache_dir = model_path
print(f"Using cached model from: {cache_dir}")
else:
cache_dir = snapshot_download(repo_id="lerobot/pi05_base", repo_type="model")
print(f"Downloaded model to: {cache_dir}")
# Try to load safetensors format first
model_file = os.path.join(cache_dir, "model.safetensors")
if os.path.exists(model_file):
state_dict = safetensors.torch.load_file(model_file)
print(f"Loaded {len(state_dict)} parameters from safetensors")
else:
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
# Load the state dict into the model
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
if missing_keys:
print(f"Missing keys: {len(missing_keys)}")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"Unexpected keys: {len(unexpected_keys)}")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All pretrained weights loaded successfully!")
else:
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
except Exception as e:
print(f"Failed to load pretrained weights: {e}")
print(" Using randomly initialized weights...")
import traceback
traceback.print_exc()
policy.to(DEVICE)
return policy
def create_dummy_data():
batch_size = 2
batch_size = 2 # Reduce batch size for testing
device = DEVICE
# Use the exact same prompt for both implementations
prompt = "Pick up the red block and place it in the bin"
return {
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
ACTION: torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
batch = {
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
"action": torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
),
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
"observation.images.left_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
"observation.images.right_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
"task": [prompt for _ in range(batch_size)],
}
return batch
def prepare_parity_inputs(lerobot_pi05, lerobot_preprocessor):
torch.manual_seed(0)
raw_batch = create_dummy_data()
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
openpi_observation = make_openpi_observation_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
max_token_len=DUMMY_MAX_TOKEN_LEN,
dataset_stats=DUMMY_DATASET_STATS,
pi05=True,
)
openpi_actions = openpi_model_actions_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
dataset_stats=DUMMY_DATASET_STATS,
pi05=True,
)
assert_processor_inputs_match_lerobot(
lerobot_pi05,
lerobot_batch,
openpi_observation,
compare_state=False,
)
batch_size = raw_batch[OBS_STATE].shape[0]
noise = torch.randn(
batch_size,
DUMMY_ACTION_HORIZON,
DUMMY_ACTION_DIM,
dtype=torch.float32,
device=DEVICE,
)
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
return lerobot_batch, openpi_observation, openpi_actions, noise, time
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
"""Extract the exact same processed inputs that LeRobot uses internally."""
# Get the tokenized language from LeRobot's internal method
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
# Get the preprocessed images from LeRobot's internal method
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
# Create dummy token_ar_mask and token_loss_mask for original implementation
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(
compile_model=compile_model,
gradient_checkpointing=gradient_checkpointing,
)
original_pi05 = instantiate_original_pi05()
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
lerobot_pi05,
lerobot_preprocessor,
class PI05Observation:
"""Observation class that matches the original OpenPI format."""
def __init__(
self,
state,
images,
image_masks,
tokenized_prompt,
tokenized_prompt_mask,
token_ar_mask,
token_loss_mask,
):
self.state = state
self.images = images
self.image_masks = image_masks
self.tokenized_prompt = tokenized_prompt
self.tokenized_prompt_mask = tokenized_prompt_mask
self.token_ar_mask = token_ar_mask
self.token_loss_mask = token_loss_mask
def create_original_observation_with_openpi_preprocessing(batch):
"""Create observation object for OpenPI using OpenPI's own preprocessing with pi05 state tokenizer."""
batch_size = batch["observation.state"].shape[0]
device = batch["observation.state"].device
# Create tokenizer for OpenPI (same as LeRobot uses)
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
# Get task description (pi05 processor handles all text formatting)
tasks = batch.get("task", ["Pick up the object"] * batch_size)
if isinstance(tasks, str):
tasks = [tasks] * batch_size
elif len(tasks) == 1:
tasks = tasks * batch_size
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
state = batch["observation.state"]
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
from lerobot.policies.pi05.modeling_pi05 import pad_vector
state = pad_vector(state, DUMMY_STATE_DIM)
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Create pi05-formatted prompts that include state information
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
# Tokenize with max_length padding to match OpenPI's expected format
tokenized = tokenizer(
full_prompts,
padding="max_length",
padding_side="right",
truncation=True,
max_length=DUMMY_MAX_TOKEN_LEN,
return_tensors="pt",
)
if gradient_checkpointing:
lerobot_pi05.train()
else:
lerobot_pi05.eval()
original_pi05.eval()
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
with fixed_flow_sampling(lerobot_pi05.model, noise=noise, time=time):
lerobot_loss, _ = lerobot_pi05(lerobot_batch, reduction="none")
with deterministic_openpi_forward_preprocess(original_pi05):
openpi_losses = original_pi05(openpi_observation, openpi_actions, noise=noise, time=time)
openpi_loss = openpi_losses.mean(dim=(1, 2))
# Create dummy token_ar_mask and token_loss_mask for OpenPI
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
image_dict = {
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
}
# Create image masks (all ones for real images)
image_masks_dict = {}
for key in image_dict:
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(compile_model=compile_model)
original_pi05 = instantiate_original_pi05()
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
lerobot_pi05,
lerobot_preprocessor,
# Create raw observation object (before preprocessing)
raw_observation = PI05Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
lerobot_pi05.eval()
original_pi05.eval()
# Now use OpenPI's preprocessing
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
return processed_obs
def create_original_observation_from_lerobot(lerobot_pi0, batch):
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
_batch_size = batch["observation.state"].shape[0]
_device = batch["observation.state"].device
# Extract the exact same processed inputs that LeRobot uses
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
extract_lerobot_processed_inputs(lerobot_pi0, batch)
)
# Convert images list to dict with original OpenPI keys
image_dict = {
"base_0_rgb": images[0],
"left_wrist_0_rgb": images[1],
"right_wrist_0_rgb": images[2],
}
# Convert image masks list to dict with original OpenPI keys
image_masks_dict = {
"base_0_rgb": img_masks[0],
"left_wrist_0_rgb": img_masks[1],
"right_wrist_0_rgb": img_masks[2],
}
return PI05Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
def test_pi05_original_vs_lerobot():
"""Test PI05 original implementation vs LeRobot implementation."""
print("Initializing models...")
lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05(
from_pretrained=True
) # Load pretrained LeRobot model
original_pi0 = instantiate_original_pi05(
from_pretrained=True
) # Load pretrained OpenPI model from HuggingFace Hub
print("Creating dummy data...")
batch = create_dummy_data()
batch_lerobot = deepcopy(batch)
# Test each model with its own preprocessing (more realistic end-to-end test)
print("\nTest each model with its own preprocessing")
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
print(f"Task prompt: '{batch['task'][0]}'")
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
print("Testing OpenPI with own preprocessing...")
original_pi0.eval()
torch.manual_seed(42) # Set seed for reproducibility
batch_size = batch["observation.state"].shape[0]
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
lerobot_actions = lerobot_pi05.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
openpi_actions = original_pi05.sample_actions(
device=DEVICE,
observation=openpi_observation,
noise=noise,
num_steps=10,
openpi_actions = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
)
openpi_actions_unit = openpi_actions[:, 0, :]
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
print("Testing LeRobot with own preprocessing...")
lerobot_pi05.eval()
torch.manual_seed(42) # Set the same seed
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
with torch.no_grad():
lerobot_actions_own = lerobot_pi05.predict_action_chunk(
batch_lerobot_processed
) # batch_size, n_action_steps, action_dim
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
def test_pi05_forward_matches_openpi():
assert_forward_matches()
print("\nComparing end-to-end implementations:")
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
def test_pi05_sample_actions_match_openpi():
assert_sample_actions_match_openpi()
def test_pi05_gradient_checkpointing_forward_matches_openpi():
assert_forward_matches(gradient_checkpointing=True)
def test_pi05_compile_forward_matches_openpi():
assert_forward_matches(compile_model=True)
def test_pi05_compile_sample_actions_match_openpi():
assert_sample_actions_match_openpi(compile_model=True)
def test_pi05_compile_gradient_checkpointing_forward_matches_openpi():
assert_forward_matches(compile_model=True, gradient_checkpointing=True)
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
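
Reviewer note: `create_original_observation_with_openpi_preprocessing` in the file above discretizes each normalized state value into 256 bins with `np.digitize(...) - 1` and then embeds the bin indices in the prompt string. The sketch below reproduces that binning and prompt format with the standard library only, assuming the state is already normalized to roughly [-1, 1]. `discretize` and `format_prompt` are hypothetical names introduced for illustration.

```python
from bisect import bisect_right

NUM_BINS = 256
# Left edges of 256 equal-width bins starting at -1, intended to match
# np.linspace(-1, 1, NUM_BINS + 1)[:-1] from the test.
BIN_EDGES = [-1.0 + i * (2.0 / NUM_BINS) for i in range(NUM_BINS)]


def discretize(value):
    """Bin index for one state value: the count of edges <= value, minus one."""
    return bisect_right(BIN_EDGES, value) - 1


def format_prompt(task, state):
    """Build the prompt string in the same shape as the test's f-string."""
    cleaned = task.strip().replace("_", " ").replace("\n", " ")
    state_str = " ".join(str(discretize(v)) for v in state)
    return f"Task: {cleaned}, State: {state_str};\nAction: "


print(discretize(-1.0), discretize(0.0), discretize(0.999))
print(repr(format_prompt("pick_up", [-1.0, 0.0])))
```

One property worth noting: values below -1 fall before the first edge and map to index -1, while values at or above 1 map to 255. The test code does not clamp, so this only matters if the state is not normalized beforehand.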

@@ -1,99 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.policies.pi0 import PI0Config # noqa: E402
from lerobot.policies.pi0.modeling_pi0 import PI0Pytorch # noqa: E402
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
assert_cache_stability,
assert_compiled_output_matches_eager,
assert_explain_has_no_graph_breaks,
benchmark_runtime,
make_compile_config,
reset_compile_state,
)
from tests.utils import require_cuda # noqa: E402
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
)
def _make_model(*, compile_model):
return PI0Pytorch(make_compile_config(PI0Config, compile_model=compile_model)).cuda().eval()
def _make_dummy_inputs(config):
device = torch.device("cuda")
common = {
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
"lang_tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
"lang_masks": torch.ones(1, 5, dtype=torch.bool, device=device),
"state": torch.randn(1, config.max_state_dim, device=device),
}
forward_kwargs = {
**common,
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"time": torch.rand(1, device=device),
}
sample_kwargs = {
**common,
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"num_steps": config.num_inference_steps,
}
return forward_kwargs, sample_kwargs
@require_cuda
def test_pi0_torch_compile_forward_and_sample_actions():
if not hasattr(torch, "compile"):
pytest.skip("torch.compile is not available")
if not torch._dynamo.is_dynamo_supported():
pytest.skip("torch._dynamo is not supported on this platform")
torch.manual_seed(0)
eager_model = _make_model(compile_model=False)
torch.manual_seed(0)
compiled_model = _make_model(compile_model=True)
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
try:
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi0.forward")
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi0.sample_actions")
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi0.forward")
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions")
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi0.forward")
benchmark_runtime(
eager_model.sample_actions, compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions"
)
finally:
reset_compile_state()
del eager_model
del compiled_model
torch.cuda.empty_cache()
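
Reviewer note: the test above calls `torch.manual_seed(0)` before building the eager model and again before building the compiled model, presumably so both start from identical random weights and any output difference can be attributed to compilation. The sketch below shows the same reseeding pattern with Python's `random` module as a stand-in for the torch RNG. `make_weights` is a hypothetical placeholder for model construction.

```python
import random


def make_weights(n):
    """Stand-in for model construction: draw n 'weights' from the global RNG."""
    return [random.uniform(-1.0, 1.0) for _ in range(n)]


random.seed(0)
eager_weights = make_weights(4)

random.seed(0)  # reseed so the second "model" starts from the same RNG state
compiled_weights = make_weights(4)

# Without reseeding, the next draw continues the stream and differs.
unseeded_weights = make_weights(4)

print(eager_weights == compiled_weights)
print(unseeded_weights == eager_weights)
```

Dropping the second seed call would make the two models differ at initialization, which would turn the eager-versus-compiled comparison into a comparison of two unrelated models.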

@@ -14,56 +14,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare LeRobot PI0 against the vendored OpenPI PyTorch reference."""
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation"""
import gc
import os
from copy import deepcopy
from typing import Any
import pytest
import torch
# Skip if openpi or transformers is not available
pytest.importorskip("openpi")
pytest.importorskip("transformers")
from lerobot.configs import PreTrainedConfig # noqa: E402
from lerobot.policies.pi0 import PI0Policy # noqa: E402
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
assert_processor_inputs_match_lerobot,
clone_batch,
deterministic_openpi_forward_preprocess,
fix_reference_state_dict,
fixed_flow_sampling,
load_openpi_reference_state_dict,
make_openpi_observation_from_raw,
openpi_model_actions_from_raw,
)
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from transformers import AutoTokenizer # noqa: E402
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
from lerobot.types import PolicyAction # noqa: E402
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 48
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
COMPILE_MODE = "default"
FORWARD_RTOL = 1e-4
FORWARD_ATOL = 1e-4
SAMPLE_RTOL = 1e-2
SAMPLE_ATOL = 5e-3
DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05)
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
DUMMY_DATASET_STATS = {
OBS_STATE: {
"observation.state": {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
"q01": torch.zeros(DUMMY_STATE_DIM),
"q99": torch.ones(DUMMY_STATE_DIM),
},
ACTION: {
"action": {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
"q01": torch.zeros(DUMMY_ACTION_DIM),
@@ -92,15 +87,6 @@ DUMMY_DATASET_STATS = {
}
@pytest.fixture(autouse=True)
def cleanup_cuda_after_test():
yield
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
class PI0BaseOriginalConfig:
action_dim: int = DUMMY_ACTION_DIM
action_horizon: int = DUMMY_ACTION_HORIZON
@@ -109,156 +95,333 @@ class PI0BaseOriginalConfig:
precision: str = "float32"
pi05: bool = False
dtype: str = "float32"
pytorch_compile_mode: str | None = None
def instantiate_lerobot_pi0(*, compile_model: bool = False, gradient_checkpointing: bool = False):
config = PreTrainedConfig.from_pretrained("lerobot/pi0_base")
config.device = str(DEVICE)
config.dtype = "float32"
config.compile_model = compile_model
config.compile_mode = COMPILE_MODE
config.gradient_checkpointing = gradient_checkpointing
def instantiate_lerobot_pi0(
from_pretrained: bool = False,
) -> tuple[
PI0Policy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
if from_pretrained:
# Load the policy first
policy = PI0Policy.from_pretrained(pretrained_name_or_path="lerobot/pi0_base", strict=True)
else:
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
policy = PI0Policy(config)
policy = PI0Policy.from_pretrained("lerobot/pi0_base", config=config, strict=True)
policy.to(DEVICE)
policy.config.device = str(DEVICE)
preprocessor, _ = make_pi0_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
return policy, preprocessor
policy.config.device = DEVICE
preprocessor, postprocessor = make_pi0_pre_post_processors(
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
)
return (policy, preprocessor, postprocessor)
def instantiate_original_pi0():
policy = PI0Pytorch(PI0BaseOriginalConfig()).to(DEVICE)
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi0_base"))
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
assert missing_keys == []
assert unexpected_keys == []
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None):
config = PI0BaseOriginalConfig()
policy = PI0Pytorch(config)
if from_pretrained:
try:
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi0_base)...")
# Download the model from HuggingFace Hub
import safetensors.torch
from huggingface_hub import snapshot_download
# Download the entire repository
if model_path and os.path.exists(model_path):
cache_dir = model_path
print(f"Using cached model from: {cache_dir}")
else:
cache_dir = snapshot_download(repo_id="lerobot/pi0_base", repo_type="model")
print(f"Downloaded model to: {cache_dir}")
# Try to load safetensors format first
model_file = os.path.join(cache_dir, "model.safetensors")
if os.path.exists(model_file):
state_dict = safetensors.torch.load_file(model_file)
print(f"Loaded {len(state_dict)} parameters from safetensors")
else:
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
# Load the state dict into the model
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
if missing_keys:
print(f"Missing keys: {len(missing_keys)}")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"Unexpected keys: {len(unexpected_keys)}")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All pretrained weights loaded successfully!")
else:
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
except Exception as e:
print(f"Failed to load pretrained weights: {e}")
print(" Using randomly initialized weights...")
import traceback
traceback.print_exc()
policy.to(DEVICE)
return policy
def create_dummy_data():
batch_size = 2
batch_size = 2 # Reduce batch size for testing
device = DEVICE
# Use the exact same prompt for both implementations
prompt = "Pick up the red block and place it in the bin"
return {
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
ACTION: torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
batch = {
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
"action": torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
),
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
"observation.images.left_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
"observation.images.right_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
# Add the task prompt for LeRobot, repeated once per batch element
"task": [prompt for _ in range(batch_size)],
}
return batch
def prepare_parity_inputs(lerobot_pi0, lerobot_preprocessor):
torch.manual_seed(0)
raw_batch = create_dummy_data()
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
openpi_observation = make_openpi_observation_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
max_token_len=DUMMY_MAX_TOKEN_LEN,
dataset_stats=DUMMY_DATASET_STATS,
pi05=False,
)
openpi_actions = openpi_model_actions_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
dataset_stats=DUMMY_DATASET_STATS,
pi05=False,
)
assert_processor_inputs_match_lerobot(
lerobot_pi0,
lerobot_batch,
openpi_observation,
compare_state=True,
)
batch_size = raw_batch[OBS_STATE].shape[0]
noise = torch.randn(
batch_size,
DUMMY_ACTION_HORIZON,
DUMMY_ACTION_DIM,
dtype=torch.float32,
device=DEVICE,
)
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
return lerobot_batch, openpi_observation, openpi_actions, noise, time
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
"""Extract the exact same processed inputs that LeRobot uses internally."""
# Get the tokenized language from LeRobot's internal method
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
# Get the preprocessed images from LeRobot's internal method
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
# Create dummy token_ar_mask and token_loss_mask for original implementation
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(
compile_model=compile_model,
gradient_checkpointing=gradient_checkpointing,
)
original_pi0 = instantiate_original_pi0()
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
lerobot_pi0,
lerobot_preprocessor,
)
class PI0Observation:
"""Observation class that matches the original OpenPI format."""
if gradient_checkpointing:
lerobot_pi0.train()
def __init__(
self,
state,
images,
image_masks,
tokenized_prompt,
tokenized_prompt_mask,
token_ar_mask,
token_loss_mask,
):
self.state = state
self.images = images
self.image_masks = image_masks
self.tokenized_prompt = tokenized_prompt
self.tokenized_prompt_mask = tokenized_prompt_mask
self.token_ar_mask = token_ar_mask
self.token_loss_mask = token_loss_mask
def create_original_observation_with_openpi_preprocessing(batch):
"""Create observation object for OpenPI using OpenPI's own preprocessing."""
batch_size = batch["observation.state"].shape[0]
device = batch["observation.state"].device
# Create tokenizer for OpenPI (same as LeRobot uses)
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
# Get task description
if "task" in batch:
tasks = batch["task"]
if isinstance(tasks, str):
# Single string: add newline if not present, then convert to list
if not tasks.endswith("\n"):
tasks = f"{tasks}\n"
tasks = [tasks]
elif isinstance(tasks, list) and all(isinstance(t, str) for t in tasks):
# List of strings: add newline to each if not present
tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks]
if len(tasks) == 1:
# Expand to batch size
tasks = tasks * batch_size
if len(tasks) != batch_size:
raise ValueError(f"Expected batch size {batch_size}, got {len(tasks)}")
# If task is neither string nor list of strings, leave unchanged
else:
lerobot_pi0.eval()
original_pi0.eval()
# Default task if not provided
tasks = ["Pick up the object\n"] * batch_size
with fixed_flow_sampling(lerobot_pi0.model, noise=noise, time=time):
lerobot_loss, _ = lerobot_pi0(lerobot_batch, reduction="none")
with deterministic_openpi_forward_preprocess(original_pi0):
openpi_losses = original_pi0(openpi_observation, openpi_actions, noise=noise, time=time)
openpi_loss = openpi_losses.mean(dim=(1, 2))
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(compile_model=compile_model)
original_pi0 = instantiate_original_pi0()
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
lerobot_pi0,
lerobot_preprocessor,
# Tokenize with max_length padding to match OpenPI's expected format
tokenized = tokenizer(
tasks,
padding="max_length",
padding_side="right",
truncation=True,
max_length=DUMMY_MAX_TOKEN_LEN,
return_tensors="pt",
)
lerobot_pi0.eval()
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
# Create dummy token_ar_mask and token_loss_mask for OpenPI
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
image_dict = {
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
}
# Create image masks (all ones for real images)
image_masks_dict = {}
for key in image_dict:
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
# Create raw observation object (before preprocessing)
raw_observation = PI0Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
# Now use OpenPI's preprocessing
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
return processed_obs
def create_original_observation_from_lerobot(lerobot_pi0, batch):
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
_batch_size = batch["observation.state"].shape[0]
_device = batch["observation.state"].device
# Extract the exact same processed inputs that LeRobot uses
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
extract_lerobot_processed_inputs(lerobot_pi0, batch)
)
# Convert images list to dict with original OpenPI keys
image_dict = {
"base_0_rgb": images[0],
"left_wrist_0_rgb": images[1],
"right_wrist_0_rgb": images[2],
}
# Convert image masks list to dict with original OpenPI keys
image_masks_dict = {
"base_0_rgb": img_masks[0],
"left_wrist_0_rgb": img_masks[1],
"right_wrist_0_rgb": img_masks[2],
}
return PI0Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
def test_pi0_original_vs_lerobot():
"""Test PI0 original implementation vs LeRobot implementation."""
print("Initializing models...")
lerobot_pi0, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi0(
from_pretrained=True
) # Load pretrained LeRobot model
original_pi0 = instantiate_original_pi0(
from_pretrained=True
) # Load pretrained OpenPI model from HuggingFace Hub
print("Creating dummy data...")
batch = create_dummy_data()
batch_lerobot = deepcopy(batch)
# Test each model with its own preprocessing (more realistic end-to-end test)
print("\nTest each model with its own preprocessing")
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
print(f"Task prompt: '{batch['task'][0]}'")
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
print("Testing OpenPI with own preprocessing...")
original_pi0.eval()
torch.manual_seed(42) # Set seed for reproducibility
batch_size = batch["observation.state"].shape[0]
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
lerobot_actions = lerobot_pi0.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
openpi_actions = original_pi0.sample_actions(
device=DEVICE,
observation=openpi_observation,
noise=noise,
num_steps=10,
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
)
openpi_actions_unit = openpi_actions[:, 0, :]
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
print("Testing LeRobot with own preprocessing...")
lerobot_pi0.eval()
torch.manual_seed(42) # Set the same seed
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
with torch.no_grad():
lerobot_actions_own = lerobot_pi0.predict_action_chunk(
batch_lerobot_processed
) # batch_size, n_action_steps, action_dim
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
def test_pi0_forward_matches_openpi():
assert_forward_matches()
print("\nComparing end-to-end implementations:")
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
def test_pi0_sample_actions_match_openpi():
assert_sample_actions_match_openpi()
def test_pi0_gradient_checkpointing_forward_matches_openpi():
assert_forward_matches(gradient_checkpointing=True)
def test_pi0_compile_forward_matches_openpi():
assert_forward_matches(compile_model=True)
def test_pi0_compile_sample_actions_match_openpi():
assert_sample_actions_match_openpi(compile_model=True)
def test_pi0_compile_gradient_checkpointing_forward_matches_openpi():
assert_forward_matches(compile_model=True, gradient_checkpointing=True)
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
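The parity assertions above compare two action tensors under an absolute and a relative tolerance. As a rough illustration in plain Python (not code from this test suite), the elementwise criterion documented for `torch.testing.assert_close` is `|actual - expected| <= atol + rtol * |expected|`. The helper name `within_tolerance` is hypothetical and exists only for this sketch.

```python
def within_tolerance(actual, expected, *, rtol, atol):
    """Return True if every pair satisfies |a - e| <= atol + rtol * |e|."""
    if len(actual) != len(expected):
        raise ValueError("length mismatch")
    # Each element gets its own bound, scaled by the magnitude of the expected value.
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


# A deviation of 0.004 passes a loose absolute tolerance and fails a strict one.
loose_ok = within_tolerance([1.004], [1.0], rtol=1e-5, atol=1e-2)
strict_ok = within_tolerance([1.004], [1.0], rtol=1e-5, atol=1e-4)
```

With `rtol` this small, `atol` dominates the bound for values near 1, which is why the same pair of outputs can pass at `atol=1e-2` and fail at `atol=1e-4`.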

@@ -1 +0,0 @@
"""Utilities shared by PI0/PI05 policy tests."""

@@ -1,291 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
import numpy as np
import safetensors.torch
import torch
import torch.nn.functional as F # noqa: N812
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as openpi_preprocessing
IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
TOKENIZER_NAME = "google/paligemma-3b-pt-224"
@dataclass
class OpenPIObservation:
state: torch.Tensor
images: dict[str, torch.Tensor]
image_masks: dict[str, torch.Tensor]
tokenized_prompt: torch.Tensor
tokenized_prompt_mask: torch.Tensor
token_ar_mask: torch.Tensor
token_loss_mask: torch.Tensor
@lru_cache(maxsize=1)
def paligemma_tokenizer():
return AutoTokenizer.from_pretrained(TOKENIZER_NAME)
def clone_batch(batch: dict) -> dict:
return {
key: value.clone() if isinstance(value, torch.Tensor) else list(value) for key, value in batch.items()
}
def pad_last_dim(tensor: torch.Tensor, target_dim: int) -> torch.Tensor:
if tensor.shape[-1] > target_dim:
raise ValueError(f"Cannot pad last dimension {tensor.shape[-1]} down to {target_dim}")
return F.pad(tensor, (0, target_dim - tensor.shape[-1]))
def mean_std_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
mean = stats["mean"].to(device=tensor.device, dtype=tensor.dtype)
std = stats["std"].to(device=tensor.device, dtype=tensor.dtype)
return (tensor - mean) / (std + 1e-8)
def quantile_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
q01 = stats["q01"].to(device=tensor.device, dtype=tensor.dtype)
q99 = stats["q99"].to(device=tensor.device, dtype=tensor.dtype)
denom = torch.where(q99 == q01, torch.full_like(q99, 1e-8), q99 - q01)
return 2.0 * (tensor - q01) / denom - 1.0
def openpi_model_state_from_raw(
batch: dict[str, torch.Tensor],
*,
action_dim: int,
dataset_stats: dict[str, dict[str, torch.Tensor]],
pi05: bool,
) -> torch.Tensor:
state = batch[OBS_STATE].to(dtype=torch.float32)
if pi05:
state = quantile_normalize(state, dataset_stats[OBS_STATE])
else:
state = mean_std_normalize(state, dataset_stats[OBS_STATE])
return pad_last_dim(state, action_dim)
def openpi_model_actions_from_raw(
batch: dict[str, torch.Tensor],
*,
action_dim: int,
dataset_stats: dict[str, dict[str, torch.Tensor]],
pi05: bool,
) -> torch.Tensor:
actions = batch[ACTION].to(dtype=torch.float32)
if pi05:
actions = quantile_normalize(actions, dataset_stats[ACTION])
else:
actions = mean_std_normalize(actions, dataset_stats[ACTION])
return pad_last_dim(actions, action_dim)
def _tasks_from_raw(batch: dict, batch_size: int) -> list[str]:
tasks = batch.get("task")
if tasks is None:
raise ValueError("The parity batch must include a task prompt.")
if isinstance(tasks, str):
return [tasks] * batch_size
if len(tasks) == 1:
return [tasks[0]] * batch_size
if len(tasks) != batch_size:
raise ValueError(f"Expected {batch_size} task prompts, got {len(tasks)}")
return list(tasks)
def _format_pi0_prompts(tasks: list[str]) -> list[str]:
return [f"{task.strip().replace('_', ' ').replace(chr(10), ' ')}\n" for task in tasks]
def _format_pi05_prompts(tasks: list[str], normalized_state: torch.Tensor) -> list[str]:
state_np = normalized_state.detach().cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
prompts = []
for task, state in zip(tasks, discretized_states, strict=True):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, state))
prompts.append(f"Task: {cleaned_text}, State: {state_str};\nAction: ")
return prompts
def _tokenize_prompts(prompts: list[str], *, max_token_len: int, device: torch.device | str):
tokenized = paligemma_tokenizer()(
prompts,
padding="max_length",
padding_side="right",
truncation=True,
max_length=max_token_len,
return_tensors="pt",
)
tokens = tokenized["input_ids"].to(device)
masks = tokenized["attention_mask"].to(device=device, dtype=torch.bool)
return tokens, masks
def make_openpi_observation_from_raw(
batch: dict[str, torch.Tensor],
*,
action_dim: int,
max_token_len: int,
dataset_stats: dict[str, dict[str, torch.Tensor]],
pi05: bool,
) -> OpenPIObservation:
batch_size = batch[OBS_STATE].shape[0]
device = batch[OBS_STATE].device
state = openpi_model_state_from_raw(
batch,
action_dim=action_dim,
dataset_stats=dataset_stats,
pi05=pi05,
)
tasks = _tasks_from_raw(batch, batch_size)
prompts = _format_pi05_prompts(tasks, state) if pi05 else _format_pi0_prompts(tasks)
tokens, masks = _tokenize_prompts(prompts, max_token_len=max_token_len, device=device)
images = {
key: batch[f"observation.images.{key}"].to(device=device, dtype=torch.float32) * 2.0 - 1.0
for key in IMAGE_KEYS
}
image_masks = {key: torch.ones(batch_size, dtype=torch.bool, device=device) for key in IMAGE_KEYS}
return OpenPIObservation(
state=state,
images=images,
image_masks=image_masks,
tokenized_prompt=tokens,
tokenized_prompt_mask=masks,
token_ar_mask=torch.zeros_like(tokens, dtype=torch.int32),
token_loss_mask=torch.ones_like(masks, dtype=torch.bool),
)
def assert_processor_inputs_match_lerobot(
lerobot_policy,
lerobot_batch: dict[str, torch.Tensor],
openpi_observation: OpenPIObservation,
*,
compare_state: bool,
):
openpi_processed = openpi_preprocessing.preprocess_observation_pytorch(openpi_observation, train=False)
lerobot_images, lerobot_image_masks = lerobot_policy._preprocess_images(lerobot_batch)
# Token IDs, token masks, images, image masks, and PI0 state are intentionally built from the same
# raw batch through independent LeRobot/OpenPI-style processor logic. They must be bitwise equal.
torch.testing.assert_close(
openpi_observation.tokenized_prompt, lerobot_batch[OBS_LANGUAGE_TOKENS], rtol=0, atol=0
)
torch.testing.assert_close(
openpi_observation.tokenized_prompt_mask,
lerobot_batch[OBS_LANGUAGE_ATTENTION_MASK],
rtol=0,
atol=0,
)
for openpi_image, lerobot_image in zip(openpi_processed.images.values(), lerobot_images, strict=True):
torch.testing.assert_close(openpi_image, lerobot_image, rtol=0, atol=0)
for openpi_mask, lerobot_mask in zip(
openpi_processed.image_masks.values(), lerobot_image_masks, strict=True
):
torch.testing.assert_close(openpi_mask, lerobot_mask, rtol=0, atol=0)
if compare_state:
torch.testing.assert_close(
openpi_processed.state, lerobot_policy.prepare_state(lerobot_batch), rtol=0, atol=0
)
def load_openpi_reference_state_dict(repo_id: str) -> dict[str, torch.Tensor]:
cache_dir = Path(snapshot_download(repo_id=repo_id, repo_type="model"))
return safetensors.torch.load_file(cache_dir / "model.safetensors")
def fix_reference_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
fixed_state_dict = dict(state_dict)
lm_head_key = "paligemma_with_expert.paligemma.lm_head.weight"
embed_tokens_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
if lm_head_key in fixed_state_dict and embed_tokens_key not in fixed_state_dict:
fixed_state_dict[embed_tokens_key] = fixed_state_dict[lm_head_key].clone()
return fixed_state_dict
@contextmanager
def fixed_flow_sampling(model, *, noise: torch.Tensor, time: torch.Tensor) -> Iterator[None]:
original_sample_noise = model.sample_noise
original_sample_time = model.sample_time
def sample_noise(shape, device):
if tuple(shape) != tuple(noise.shape):
raise ValueError(f"Expected noise shape {tuple(noise.shape)}, got {tuple(shape)}")
return noise.to(device=device)
def sample_time(batch_size, device):
if batch_size != time.shape[0]:
raise ValueError(f"Expected time batch size {time.shape[0]}, got {batch_size}")
return time.to(device=device)
model.sample_noise = sample_noise
model.sample_time = sample_time
try:
yield
finally:
model.sample_noise = original_sample_noise
model.sample_time = original_sample_time
@contextmanager
def deterministic_openpi_forward_preprocess(openpi_policy) -> Iterator[None]:
"""Disable OpenPI's training-time image augmentation only inside a parity forward block.
OpenPI's `forward()` calls `_preprocess_observation(..., train=True)`, which can apply stochastic
image augmentation. LeRobot's policy forward path does not apply that augmentation, so parity would
otherwise compare two different image tensors rather than two model implementations. The context manager
keeps the public `openpi_policy.forward(observation, ...)` call while making preprocessing deterministic.
`yield` marks the body of the caller's `with` block. The `try/finally` restores the original method even
if the assertion inside the block fails, so the temporary monkeypatch cannot leak into later tests.
"""
original_preprocess_observation = openpi_policy._preprocess_observation
def preprocess_observation(observation, *, train=True):
return original_preprocess_observation(observation, train=False)
openpi_policy._preprocess_observation = preprocess_observation
try:
yield
finally:
openpi_policy._preprocess_observation = original_preprocess_observation
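The temporary-monkeypatch pattern described in the docstring above can be tried in isolation. The sketch below assumes a toy object in place of the real policy: `ToyPolicy` and `force_eval_preprocess` are hypothetical names, and the string labels stand in for augmented and unaugmented image tensors.

```python
from contextlib import contextmanager


class ToyPolicy:
    """Hypothetical stand-in for a policy whose preprocessing depends on `train`."""

    def _preprocess_observation(self, observation, *, train=True):
        return ("augmented" if train else "clean", observation)

    def forward(self, observation):
        # Like the policy described above, forward always asks for train=True.
        return self._preprocess_observation(observation, train=True)


@contextmanager
def force_eval_preprocess(policy):
    original = policy._preprocess_observation

    def preprocess(observation, *, train=True):
        # Ignore the caller's flag and always take the deterministic path.
        return original(observation, train=False)

    policy._preprocess_observation = preprocess
    try:
        yield
    finally:
        # Runs even if the body raises, so the patch cannot leak out.
        policy._preprocess_observation = original


policy = ToyPolicy()
with force_eval_preprocess(policy):
    inside = policy.forward("obs")
outside = policy.forward("obs")
```

Patching the instance attribute rather than the class keeps the change local to one policy object, so other instances built in the same test session are unaffected.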

@@ -1,207 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from collections.abc import Callable
import torch
from torch._dynamo.utils import counters, guard_failures
from torch.profiler import ProfilerActivity
FORWARD_RTOL = 1e-5
FORWARD_ATOL = 5e-2
SAMPLE_RTOL = 1e-5
SAMPLE_ATOL = 1e-2
COMPILE_MODE = "max-autotune"
STEADY_STATE_WARMUPS = 3
STEADY_STATE_REPEATS = 3
def make_compile_config(config_cls, *, compile_model):
return config_cls(device="cuda", compile_model=compile_model, compile_mode=COMPILE_MODE)
def counter_total(name):
return sum(counters.get(name, {}).values())
def compile_snapshot():
return {
"graph_breaks": counter_total("graph_break"),
"recompiles": counter_total("recompiles"),
"recompile_limits": counter_total("recompile_limit"),
"unique_graphs": counters["stats"].get("unique_graphs", 0),
}
def reset_compile_state():
torch._dynamo.reset()
counters.clear()
guard_failures.clear()
def clone_cuda_graph_output(output):
if torch.is_tensor(output):
return output.clone()
if isinstance(output, tuple):
return tuple(clone_cuda_graph_output(item) for item in output)
if isinstance(output, list):
return [clone_cuda_graph_output(item) for item in output]
if isinstance(output, dict):
return {key: clone_cuda_graph_output(value) for key, value in output.items()}
return output
def run_model_step(fn: Callable, kwargs: dict):
if hasattr(torch.compiler, "cudagraph_mark_step_begin"):
torch.compiler.cudagraph_mark_step_begin()
return fn(**kwargs)
def assert_explain_has_no_graph_breaks(fn: Callable, kwargs: dict, label: str):
reset_compile_state()
explanation = torch._dynamo.explain(fn)(**kwargs)
assert explanation.graph_count > 0, f"{label} was not captured by Dynamo"
assert explanation.graph_break_count == 0, (
f"{label} has {explanation.graph_break_count} graph break(s): {explanation.break_reasons}"
)
assert not explanation.break_reasons, f"{label} graph break reasons: {explanation.break_reasons}"
print(
f"{label} capture: graphs={explanation.graph_count}, "
f"graph_breaks={explanation.graph_break_count}, ops={explanation.op_count}, "
f"guards={len(explanation.out_guards or [])}"
)
return explanation
@torch.no_grad()
def assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs):
eager_forward = eager_model.forward(**forward_kwargs)
compiled_forward = compiled_model.forward(**forward_kwargs)
torch.testing.assert_close(compiled_forward, eager_forward, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
eager_actions = eager_model.sample_actions(**sample_kwargs)
compiled_actions = compiled_model.sample_actions(**sample_kwargs)
torch.testing.assert_close(compiled_actions, eager_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
@torch.no_grad()
def assert_cache_stability(fn: Callable, kwargs: dict, label: str):
reset_compile_state()
first_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
first_snapshot = compile_snapshot()
second_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
second_snapshot = compile_snapshot()
third_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
third_snapshot = compile_snapshot()
torch.testing.assert_close(second_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
torch.testing.assert_close(third_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
assert first_snapshot["unique_graphs"] > 0, f"{label} did not compile any graph"
assert third_snapshot["graph_breaks"] == 0, f"{label} graph breaks: {third_snapshot}"
assert third_snapshot["recompiles"] == 0, f"{label} recompiled: {third_snapshot}"
assert third_snapshot["recompile_limits"] == 0, f"{label} hit recompile limit: {third_snapshot}"
assert second_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
f"{label} compiled new graph on second call: first={first_snapshot}, second={second_snapshot}"
)
assert third_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
f"{label} compiled new graph on third call: first={first_snapshot}, third={third_snapshot}"
)
assert not guard_failures, f"{label} guard failures: {dict(guard_failures)}"
print(f"{label} cache: first={first_snapshot}, third={third_snapshot}")
@torch.no_grad()
def benchmark_runtime(eager_fn: Callable, compiled_fn: Callable, kwargs: dict, label: str):
run_warmups(eager_fn, kwargs)
run_warmups(compiled_fn, kwargs)
torch.cuda.synchronize()
eager_metrics = profile_callable(eager_fn, kwargs)
compiled_metrics = profile_callable(compiled_fn, kwargs)
speedup = eager_metrics["cuda_event_ms"] / compiled_metrics["cuda_event_ms"]
print(
f"{label} runtime: eager_cuda={eager_metrics['cuda_event_ms']:.3f} ms, "
f"compiled_cuda={compiled_metrics['cuda_event_ms']:.3f} ms, speedup={speedup:.3f}x, "
f"host_wall_ms eager/compiled={eager_metrics['host_wall_ms']:.3f}/"
f"{compiled_metrics['host_wall_ms']:.3f}, "
f"cpu_self_time_ms eager/compiled={eager_metrics['cpu_self_time_ms']:.3f}/"
f"{compiled_metrics['cpu_self_time_ms']:.3f}, "
f"cuda_launches eager/compiled={eager_metrics['cuda_launch_count']}/"
f"{compiled_metrics['cuda_launch_count']}, "
f"profiler_events eager/compiled={eager_metrics['profiler_event_count']}/"
f"{compiled_metrics['profiler_event_count']}, "
f"peak_mem_mib eager/compiled={eager_metrics['peak_mem_mib']:.1f}/"
f"{compiled_metrics['peak_mem_mib']:.1f}"
)
assert eager_metrics["cuda_event_ms"] > 0
assert compiled_metrics["cuda_event_ms"] > 0
assert eager_metrics["profiler_event_count"] > 0
assert compiled_metrics["profiler_event_count"] > 0
return eager_metrics, compiled_metrics
def run_warmups(fn: Callable, kwargs: dict):
for _ in range(STEADY_STATE_WARMUPS):
run_model_step(fn, kwargs)
torch.cuda.synchronize()
def profile_callable(fn: Callable, kwargs: dict):
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
host_start = time.perf_counter()
start_event.record()
for _ in range(STEADY_STATE_REPEATS):
run_model_step(fn, kwargs)
end_event.record()
torch.cuda.synchronize()
cuda_event_ms = start_event.elapsed_time(end_event) / STEADY_STATE_REPEATS
host_wall_ms = (time.perf_counter() - host_start) * 1000 / STEADY_STATE_REPEATS
peak_mem_mib = torch.cuda.max_memory_allocated() / 1024**2
with torch.profiler.profile(
activities=[ProfilerActivity.CPU],
) as profiler:
run_model_step(fn, kwargs)
torch.cuda.synchronize()
key_averages = profiler.key_averages()
cpu_self_time_ms = sum(event.self_cpu_time_total for event in key_averages) / 1000
cuda_launch_count = sum(
event.count
for event in key_averages
if event.key in {"cudaLaunchKernel", "cudaGraphLaunch", "cudaLaunchKernelExC"}
)
profiler_event_count = sum(event.count for event in key_averages)
return {
"cuda_event_ms": cuda_event_ms,
"host_wall_ms": host_wall_ms,
"cpu_self_time_ms": cpu_self_time_ms,
"cuda_launch_count": cuda_launch_count,
"profiler_event_count": profiler_event_count,
"peak_mem_mib": peak_mem_mib,
}
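The warmup-then-repeat structure used by `run_warmups` and `profile_callable` above can be sketched without a GPU. This is a host-side approximation only: it uses `time.perf_counter` rather than CUDA events, and the name `time_steady_state` is hypothetical. For GPU work the device would also need to be synchronized before reading the clock, as the code above does with `torch.cuda.synchronize()`.

```python
import time


def time_steady_state(fn, *, warmups=3, repeats=3):
    """Average wall-clock milliseconds per call after discarding warmup calls."""
    for _ in range(warmups):
        fn()  # warmup calls absorb one-time costs such as compilation or caching
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) * 1000 / repeats


calls = []
avg_ms = time_steady_state(lambda: calls.append(sum(range(1000))))
```

Discarding the warmup calls matters most for compiled functions, where the first invocations include tracing and autotuning time that would otherwise dominate the average.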

@@ -1,273 +0,0 @@
#!/usr/bin/env python
"""Shared fixtures and helpers for VLA-JEPA tests."""
from __future__ import annotations
from types import SimpleNamespace
import numpy as np
import pytest
import torch
from PIL import Image
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
# ---------------------------------------------------------------------------
# Shared constants
# ---------------------------------------------------------------------------
BATCH_SIZE = 2
ACTION_DIM = 3
STATE_DIM = 4
IMAGE_SIZE = 8
ACTION_HORIZON = 4
N_ACTION_STEPS = 2
NUM_VIDEO_FRAMES = 3
QWEN_HIDDEN_SIZE = 16 # hidden size produced by _FakeQwenBackbone
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def set_seed_all(seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def make_config(
    action_dim: int = ACTION_DIM,
    state_dim: int = STATE_DIM,
    action_horizon: int = ACTION_HORIZON,
    num_video_frames: int = NUM_VIDEO_FRAMES,
) -> VLAJEPAConfig:
    config = VLAJEPAConfig(
        input_features={
            f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
            OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
        },
        output_features={
            ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
        },
        device="cpu",
        chunk_size=action_horizon,
        n_action_steps=min(N_ACTION_STEPS, action_horizon),
        action_dim=action_dim,
        state_dim=state_dim,
        num_video_frames=num_video_frames,
        num_action_tokens_per_timestep=2,
        num_embodied_action_tokens_per_instruction=3,
        num_inference_timesteps=2,
        action_hidden_size=QWEN_HIDDEN_SIZE,
        action_model_type="DiT-test",
        action_num_layers=1,
        predictor_depth=1,
        predictor_num_heads=2,
        predictor_mlp_ratio=2.0,
        jepa_tubelet_size=1,
    )
    config.validate_features()
    return config


def make_train_batch(
    batch_size: int = BATCH_SIZE,
    action_dim: int = ACTION_DIM,
    state_dim: int = STATE_DIM,
    action_horizon: int = ACTION_HORIZON,
    num_video_frames: int = NUM_VIDEO_FRAMES,
) -> dict[str, Tensor | list[str]]:
    return {
        f"{OBS_IMAGES}.laptop": torch.rand(batch_size, num_video_frames, 3, IMAGE_SIZE, IMAGE_SIZE),
        OBS_STATE: torch.randn(batch_size, 1, state_dim),
        ACTION: torch.randn(batch_size, action_horizon, action_dim),
        "task": ["pick up the cube"] * batch_size,
    }


def make_inference_batch(
    batch_size: int = BATCH_SIZE,
    state_dim: int = STATE_DIM,
) -> dict[str, Tensor | list[str]]:
    return {
        f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
        OBS_STATE: torch.randn(batch_size, state_dim),
        "task": ["pick up the cube"] * batch_size,
    }
# ---------------------------------------------------------------------------
# Fake external models (replace Qwen3-VL and V-JEPA at test time)
# ---------------------------------------------------------------------------
class _FakeLanguageLayer(nn.Module):
    """Leaf module whose forward hook is captured by _qwen_last_decoder_hidden."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self._hidden_size = hidden_size

    def forward(self, hidden: Tensor, **_: object) -> tuple[Tensor, ...]:
        return (hidden,)


class _FakeLanguageModel(nn.Module):
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self._hidden_size = hidden_size
        self.layers = nn.ModuleList([_FakeLanguageLayer(hidden_size)])

    def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
        batch_size, seq_len = input_ids.shape
        hidden = torch.zeros(batch_size, seq_len, self._hidden_size, device=input_ids.device)
        self.layers[-1](hidden)
        return SimpleNamespace()


class _FakeQwenInnerModel(nn.Module):
    """Mimics the `.model.model` level that _qwen_last_decoder_hidden walks into."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.language_model = _FakeLanguageModel(hidden_size)

    def forward(self, input_ids: Tensor, **kwargs: object) -> SimpleNamespace:
        return self.language_model(input_ids)


class _FakeQwenBackbone(nn.Module):
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(1))
        self.config = SimpleNamespace(
            hidden_size=hidden_size,
            text_config=SimpleNamespace(hidden_size=hidden_size),
        )
        self.model = _FakeQwenInnerModel(hidden_size)

    @property
    def device(self) -> torch.device:
        return self.weight.device

    def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
        batch_size, seq_len = input_ids.shape
        hidden_size = self.config.hidden_size
        values = torch.arange(
            batch_size * seq_len * hidden_size,
            device=input_ids.device,
            dtype=torch.float32,
        ).view(batch_size, seq_len, hidden_size)
        hidden = values / values.numel() + self.weight
        self.model(input_ids)  # call through so the forward hook on layers[-1] fires
        return SimpleNamespace(hidden_states=[hidden])


class _FakeQwenInterface(nn.Module):
    def __init__(self, config: VLAJEPAConfig) -> None:
        super().__init__()
        self.config = config
        self.model = _FakeQwenBackbone(hidden_size=QWEN_HIDDEN_SIZE)

    @staticmethod
    def _get_torch_dtype(dtype_name: str) -> torch.dtype:
        return torch.float32 if dtype_name == "float32" else torch.bfloat16

    def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
        max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
        action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
        action_token_ids = list(range(1000, 1000 + max_action_tokens))
        return action_tokens, action_token_ids, 2000

    def build_inputs(
        self,
        images: list[list[Image.Image]],
        instructions: list[str],
        action_prompt: str,
        embodied_prompt: str,
    ) -> dict[str, Tensor]:
        batch_size = len(images)
        del images, instructions, action_prompt, embodied_prompt
        action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
        token_ids = (
            [10]
            + list(range(1000, 1000 + action_count))
            + [2000] * self.config.num_embodied_action_tokens_per_instruction
            + [11]
        )
        return {
            "input_ids": torch.tensor(
                [token_ids] * batch_size,
                device=self.model.device,
                dtype=torch.long,
            )
        }

    @staticmethod
    def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
        image = image_tensor.detach().cpu()
        if image.ndim == 3 and image.shape[0] in (1, 3):
            image = image.permute(1, 2, 0)
        image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
        return Image.fromarray(image)


class _FakeVideoEncoder(nn.Module):
    def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(1))
        # image_size must be >= patch_size (16) so the predictor grid is non-zero.
        # Setting image_size=16 gives a 1x1 grid (1 patch per frame).
        self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size, image_size=16)

    @property
    def device(self) -> torch.device:
        return self.weight.device

    def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
        batch_size, num_frames = pixel_values_videos.shape[:2]
        hidden_size = self.config.hidden_size
        frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
        return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)


class _FakeVideoProcessor:
    def __call__(self, videos, return_tensors: str) -> dict[str, Tensor]:
        assert return_tensors == "pt"
        if isinstance(videos, list):
            pixel_values = torch.stack([torch.as_tensor(v) for v in videos])
        else:
            pixel_values = torch.as_tensor(videos).unsqueeze(0)
        return {"pixel_values_videos": pixel_values}
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
    from lerobot.policies.vla_jepa import modeling_vla_jepa

    monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
    monkeypatch.setattr(
        modeling_vla_jepa.AutoModel,
        "from_pretrained",
        lambda *args, **kwargs: _FakeVideoEncoder(),
    )
    monkeypatch.setattr(
        modeling_vla_jepa.AutoVideoProcessor,
        "from_pretrained",
        lambda *args, **kwargs: _FakeVideoProcessor(),
    )
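
The token layout that `_FakeQwenInterface.build_inputs` produces is easy to check by hand. The sketch below rebuilds it in plain Python, assuming the values passed by `make_config` (3 video frames, 2 action tokens per timestep, 3 embodied-action tokens); `fake_token_ids` is an illustrative name, not part of the test suite.

```python
def fake_token_ids(
    num_video_frames: int = 3,
    tokens_per_timestep: int = 2,
    num_embodied_tokens: int = 3,
) -> list[int]:
    """Rebuild the id sequence emitted by the fake build_inputs above."""
    # One group of action tokens per frame transition, hence frames - 1.
    action_count = (num_video_frames - 1) * tokens_per_timestep
    return (
        [10]  # arbitrary prefix token id
        + list(range(1000, 1000 + action_count))  # action token ids
        + [2000] * num_embodied_tokens  # repeated embodied-action token id
        + [11]  # arbitrary suffix token id
    )


token_ids = fake_token_ids()
```

With the default test configuration this gives a sequence of 9 ids, four of which are action tokens starting at 1000, matching the ids returned by `expand_tokenizer`.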

@@ -1,157 +0,0 @@
#!/usr/bin/env python
from __future__ import annotations
import pytest
import torch
pytest.importorskip("diffusers")
from conftest import (
ACTION_DIM,
ACTION_HORIZON,
BATCH_SIZE,
QWEN_HIDDEN_SIZE,
STATE_DIM,
make_config,
set_seed_all,
) # noqa: E402
from lerobot.policies.vla_jepa.action_head import ( # noqa: E402
VLAJEPAActionHead,
)
# ---------------------------------------------------------------------------
# VLAJEPAActionHead
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
    "action_dim,state_dim,action_horizon",
    [
        (3, 4, 4),  # default test dims
        (7, 0, 16),  # no proprioceptive state, production-like action space
        (6, 8, 8),  # medium dims
    ],
)
def test_action_head_sample_time_range(action_dim: int, state_dim: int, action_horizon: int) -> None:
    config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
    t = head.sample_time(batch_size=200, device=torch.device("cpu"), dtype=torch.float32)
    assert t.shape == (200,)
    assert torch.isfinite(t).all()


@pytest.mark.parametrize(
    "action_dim,state_dim,action_horizon",
    [
        (3, 4, 4),
        (7, 0, 16),
        (6, 8, 8),
    ],
)
def test_action_head_build_inputs_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
    config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
    conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
    actions = torch.randn(2, action_horizon, action_dim)
    timesteps = torch.randint(0, 100, (2,))
    state = torch.randn(2, state_dim) if state_dim > 0 else None
    out_with = head._build_inputs(conditioning, actions, state, timesteps)
    out_none = head._build_inputs(conditioning, actions, None, timesteps)
    assert out_with.ndim == 3 and out_none.ndim == 3
    if state_dim > 0:
        assert out_with.shape[1] > out_none.shape[1]
    assert torch.isfinite(out_with).all() and torch.isfinite(out_none).all()


@pytest.mark.parametrize(
    "action_dim,state_dim,action_horizon",
    [
        (3, 4, 4),
        (7, 0, 16),
        (6, 8, 8),
    ],
)
def test_action_head_forward_loss_valid(action_dim: int, state_dim: int, action_horizon: int) -> None:
    set_seed_all(42)
    config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
    conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
    actions = torch.randn(2, action_horizon, action_dim)
    state = torch.randn(2, state_dim) if state_dim > 0 else None
    loss = head.forward(conditioning, actions, state)
    assert loss.shape == ()
    assert torch.isfinite(loss) and loss > 0


def test_action_head_forward_gradient_flows() -> None:
    set_seed_all(42)
    config = make_config()
    head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
    conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
    actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
    state = torch.randn(BATCH_SIZE, STATE_DIM)
    loss = head.forward(conditioning, actions, state)
    loss.backward()
    assert any(p.grad is not None for p in head.parameters() if p.requires_grad)


@torch.no_grad()
@pytest.mark.parametrize(
    "action_dim,state_dim,action_horizon",
    [
        (3, 4, 4),
        (7, 0, 16),
        (6, 8, 8),
    ],
)
def test_action_head_predict_action_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
    set_seed_all(42)
    config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
    conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
    state = torch.randn(2, state_dim) if state_dim > 0 else None
    pred = head.predict_action(conditioning, state)
    assert tuple(pred.shape) == (2, action_horizon, action_dim)
    assert torch.isfinite(pred).all()
# ---------------------------------------------------------------------------
# action_is_pad masking
# ---------------------------------------------------------------------------
def test_action_head_loss_fully_padded_is_zero() -> None:
    """Loss is 0 when every timestep is padded (exercises the clamp_min guard)."""
    set_seed_all(42)
    config = make_config()
    head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
    conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
    actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
    state = torch.randn(BATCH_SIZE, STATE_DIM)
    action_is_pad = torch.ones(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
    loss = head.forward(conditioning, actions, state, action_is_pad)
    assert loss.item() == 0.0


def test_action_head_loss_none_matches_no_padding() -> None:
    """action_is_pad=None is equivalent to an all-False (no padding) mask."""
    set_seed_all(42)
    config = make_config()
    head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
    conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
    actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
    state = torch.randn(BATCH_SIZE, STATE_DIM)
    set_seed_all(0)
    loss_none = head.forward(conditioning, actions, state, action_is_pad=None)
    set_seed_all(0)
    no_pad = torch.zeros(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
    loss_zeros = head.forward(conditioning, actions, state, action_is_pad=no_pad)
    assert torch.isclose(loss_none, loss_zeros)
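
The two padding tests above pin down a masked-mean contract: padded timesteps contribute nothing, and a fully padded batch must yield 0 rather than a division by zero. The action head's real loss code is not part of this diff, so the following is only a pure-Python sketch of that contract under the assumption that the "clamp_min guard" clamps the count of valid steps to at least 1; `masked_mean_loss` is a hypothetical helper.

```python
from typing import Optional


def masked_mean_loss(
    per_step_loss: list[list[float]],
    is_pad: Optional[list[list[bool]]] = None,
) -> float:
    """Average per-timestep losses over the non-padded steps only."""
    if is_pad is None:
        # No mask supplied: treat every timestep as valid.
        is_pad = [[False] * len(row) for row in per_step_loss]
    total = 0.0
    valid = 0
    for loss_row, pad_row in zip(per_step_loss, is_pad):
        for loss, pad in zip(loss_row, pad_row):
            if not pad:
                total += loss
                valid += 1
    # Clamp the denominator to 1 so a fully padded batch returns 0.0
    # instead of raising ZeroDivisionError.
    return total / max(valid, 1)


losses = [[1.0, 3.0], [2.0, 2.0]]
all_pad = [[True, True], [True, True]]
no_pad = [[False, False], [False, False]]
half_pad = [[False, True], [False, True]]
```

Here `masked_mean_loss(losses, all_pad)` is zero, and passing `None` or `no_pad` gives the same value, which mirrors what the two tests assert on the real head.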

@@ -1,57 +0,0 @@
#!/usr/bin/env python
from __future__ import annotations
import pytest
from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
def test_delta_indices() -> None:
    config = make_config()
    assert config.observation_delta_indices == list(range(NUM_VIDEO_FRAMES))
    assert config.action_delta_indices == list(range(ACTION_HORIZON))


def test_n_action_steps_exceeds_chunk_size_raises() -> None:
    with pytest.raises(ValueError, match="n_action_steps"):
        VLAJEPAConfig(chunk_size=4, n_action_steps=8)


def test_too_few_video_frames_raises() -> None:
    with pytest.raises(ValueError, match="video_horizon"):
        VLAJEPAConfig(
            chunk_size=16,
            n_action_steps=16,
            num_video_frames=2,
            jepa_tubelet_size=2,  # needs >= 4 frames (2 for current, 2 for future) to have a window of size > 0
        )


def test_validate_features_no_image_raises() -> None:
    config = VLAJEPAConfig(
        input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))},
        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
    )
    with pytest.raises(ValueError, match="at least one visual input feature"):
        config.validate_features()


def test_validate_features_no_action_raises() -> None:
    config = VLAJEPAConfig(
        input_features={
            f"{OBS_IMAGES}.cam": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
        },
        output_features={},
    )
    with pytest.raises(ValueError, match="action output feature"):
        config.validate_features()


def test_validate_features_sets_action_dim_from_feature() -> None:
    config = make_config(action_dim=6, state_dim=10)
    assert config.action_dim == 6
    assert config.state_dim == 10
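
The constraints these tests exercise can be summarized in a toy dataclass. This is a sketch, not `VLAJEPAConfig` itself: `ToyChunkConfig` and `raises` are hypothetical, the error wording is invented, and the frame rule (at least `2 * jepa_tubelet_size` frames) is inferred from the inline comment in `test_too_few_video_frames_raises`.

```python
from dataclasses import dataclass


@dataclass
class ToyChunkConfig:
    """Hypothetical config mirroring the checks tested above."""

    chunk_size: int = 16
    n_action_steps: int = 16
    num_video_frames: int = 4
    jepa_tubelet_size: int = 2

    def __post_init__(self) -> None:
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"n_action_steps ({self.n_action_steps}) must not exceed "
                f"chunk_size ({self.chunk_size})"
            )
        # Assumed rule: one tubelet of context plus one tubelet to predict.
        if self.num_video_frames < 2 * self.jepa_tubelet_size:
            raise ValueError(
                f"num_video_frames ({self.num_video_frames}) must be at least "
                f"{2 * self.jepa_tubelet_size}"
            )

    @property
    def action_delta_indices(self) -> list[int]:
        # One index per action in the predicted chunk.
        return list(range(self.chunk_size))


def raises(**kwargs: int) -> bool:
    """Return True if building the toy config raises ValueError."""
    try:
        ToyChunkConfig(**kwargs)
    except ValueError:
        return True
    return False
```

Validating in `__post_init__` means an inconsistent configuration fails at construction time, which is the behaviour the `pytest.raises` blocks above rely on.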

@@ -1,598 +0,0 @@
#!/usr/bin/env python
from __future__ import annotations
import os
from copy import deepcopy
import numpy as np
import pytest
import torch
from torch import Tensor
pytest.importorskip("transformers")
pytest.importorskip("diffusers")
pytestmark = pytest.mark.filterwarnings(
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
)
from conftest import ( # noqa: E402
ACTION_DIM,
ACTION_HORIZON,
BATCH_SIZE,
EXPECTED_ACTION_CHUNK_SHAPE,
EXPECTED_SELECT_ACTION_SHAPE,
IMAGE_SIZE,
N_ACTION_STEPS,
QWEN_HIDDEN_SIZE,
STATE_DIM,
make_config,
make_inference_batch,
make_train_batch,
set_seed_all,
)
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig # noqa: E402
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402
from lerobot.utils.constants import ACTION # noqa: E402
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
PRETRAINED_SUBFOLDER = "LIBERO"
# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are
# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in.
_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0"
extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests")
# ---------------------------------------------------------------------------
# Core training / inference tests
# ---------------------------------------------------------------------------
def test_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.train()
    batch = make_train_batch()
    batch_before = deepcopy(batch)
    loss, logs = policy.forward(batch)
    assert loss.shape == ()
    assert torch.isfinite(loss)
    assert set(logs) == {"action_loss", "wm_loss", "loss"}
    assert logs["action_loss"] > 0
    assert logs["wm_loss"] >= 0
    loss.backward()
    assert any(p.grad is not None for p in policy.model.action_model.parameters() if p.requires_grad)
    # Batch must not be mutated.
    assert set(batch) == set(batch_before)
    for key, value in batch.items():
        if isinstance(value, Tensor):
            assert torch.equal(value, batch_before[key])
        else:
            assert value == batch_before[key]


@pytest.mark.parametrize("batch_size", [1, 2, 4])
def test_training_forward_various_batch_sizes(patch_vla_jepa_external_models: None, batch_size: int) -> None:
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.train()
    loss, logs = policy.forward(make_train_batch(batch_size=batch_size))
    assert torch.isfinite(loss) and loss > 0
    assert set(logs) == {"action_loss", "wm_loss", "loss"}


@pytest.mark.parametrize(
    "action_dim,state_dim,action_horizon",
    [
        (3, 4, 4),
        (7, 0, 16),
        (6, 8, 8),
    ],
)
def test_training_forward_various_dims(
    patch_vla_jepa_external_models: None,
    action_dim: int,
    state_dim: int,
    action_horizon: int,
) -> None:
    set_seed_all(42)
    config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    policy = VLAJEPAPolicy(config)
    policy.train()
    batch = make_train_batch(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    loss, _ = policy.forward(batch)
    assert torch.isfinite(loss) and loss > 0


@torch.no_grad()
def test_action_generation_shape(patch_vla_jepa_external_models: None) -> None:
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.eval()
    batch = make_inference_batch()
    chunk = policy.predict_action_chunk(batch)
    assert tuple(chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
    assert chunk.device.type == "cpu"
    assert torch.isfinite(chunk).all()
    a1 = policy.select_action(batch)
    a2 = policy.select_action(batch)
    assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
    assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
    assert torch.isfinite(a1).all() and torch.isfinite(a2).all()


@torch.no_grad()
@pytest.mark.parametrize("action_dim,state_dim", [(3, 4), (7, 0), (6, 8)])
def test_action_generation_various_dims(
    patch_vla_jepa_external_models: None, action_dim: int, state_dim: int
) -> None:
    set_seed_all(42)
    config = make_config(action_dim=action_dim, state_dim=state_dim)
    policy = VLAJEPAPolicy(config)
    policy.eval()
    batch = make_inference_batch(state_dim=state_dim)
    chunk = policy.predict_action_chunk(batch)
    assert chunk.shape[-1] == action_dim
    assert torch.isfinite(chunk).all()


@torch.no_grad()
def test_inference_reproducibility(patch_vla_jepa_external_models: None) -> None:
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.eval()
    batch = make_inference_batch()
    set_seed_all(123)
    actions_1 = policy.predict_action_chunk(batch)
    set_seed_all(123)
    actions_2 = policy.predict_action_chunk(batch)
    assert tuple(actions_1.shape) == EXPECTED_ACTION_CHUNK_SHAPE
    assert torch.allclose(actions_1, actions_2, atol=1e-6)


@torch.no_grad()
def test_predict_action_chunk_always_finite(patch_vla_jepa_external_models: None) -> None:
    policy = VLAJEPAPolicy(make_config())
    policy.eval()
    for seed in [0, 42, 123]:
        set_seed_all(seed)
        chunk = policy.predict_action_chunk(make_inference_batch())
        assert torch.isfinite(chunk).all(), f"non-finite actions with seed={seed}"
# ---------------------------------------------------------------------------
# Action queue behaviour
# ---------------------------------------------------------------------------
@torch.no_grad()
def test_select_action_queue_drains_before_refill(patch_vla_jepa_external_models: None) -> None:
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.eval()
    batch = make_inference_batch()
    # First call fills the queue (n_action_steps items) and pops one.
    a1 = policy.select_action(batch)
    assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 1
    # Second call pops from the existing queue without calling predict_action_chunk.
    a2 = policy.select_action(batch)
    assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
    assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE


@torch.no_grad()
def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None:
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.eval()
    policy.select_action(make_inference_batch())
    assert len(policy._queues[ACTION]) > 0
    policy.reset()
    assert len(policy._queues[ACTION]) == 0
# ---------------------------------------------------------------------------
# Format conversion
# ---------------------------------------------------------------------------
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
    from PIL import Image

    policy = VLAJEPAPolicy(make_config())
    examples = policy._prepare_model_inputs(make_train_batch())
    assert len(examples) == BATCH_SIZE
    for ex in examples:
        assert set(ex) >= {"image", "video", "lang", "action", "state"}
        assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image)
        assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8  # [V,T,H,W,C]
        assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
        assert ex["state"].shape == (1, STATE_DIM)


def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
    policy = VLAJEPAPolicy(make_config())
    for ex in policy._prepare_model_inputs(make_inference_batch()):
        assert "action" not in ex
        assert "image" in ex and "video" in ex and "lang" in ex


def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
    policy = VLAJEPAPolicy(make_config())
    batch = make_inference_batch()
    del batch["task"]
    examples = policy._prepare_model_inputs(batch)
    assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)


def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
    policy = VLAJEPAPolicy(make_config())
    batch = make_inference_batch()
    batch["task"] = "open the drawer"
    assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))


def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
    from lerobot.utils.constants import OBS_STATE

    policy = VLAJEPAPolicy(make_config())
    batch = make_inference_batch()
    del batch[OBS_STATE]
    assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
# ---------------------------------------------------------------------------
# Pretrained checkpoint
# Hub tests (opt-in: VLA_JEPA_EXTENDED=1)
# ---------------------------------------------------------------------------
def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
    """Build a training batch whose keys/shapes match a hub-loaded policy config."""
    cfg = policy.config
    batch: dict = {"task": ["pick up the cube"] * batch_size}
    for key, feat in cfg.image_features.items():
        h, w = feat.shape[-2], feat.shape[-1]
        batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w)
    if cfg.robot_state_feature is not None:
        batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0])
    batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim)
    return batch


def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
    """Build an inference batch whose keys/shapes match a hub-loaded policy config."""
    cfg = policy.config
    batch: dict = {"task": ["pick up the cube"] * batch_size}
    for key, feat in cfg.image_features.items():
        h, w = feat.shape[-2], feat.shape[-1]
        batch[key] = torch.rand(batch_size, 3, h, w)
    if cfg.robot_state_feature is not None:
        batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0])
    return batch


_CP_ROOT = "lerobot"
# Each tuple: (repo_id, enable_world_model)
_HUB_VARIANTS = [
    (f"{_CP_ROOT}/VLA-JEPA-LIBERO", True),
    (f"{_CP_ROOT}/VLA-JEPA-Pretrain", True),
    (f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False),
]


@extended_test
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None:
    """Policy loads from the converted safetensors checkpoint on the Hub."""
    policy = VLAJEPAPolicy.from_pretrained(repo_id)
    assert policy.config.enable_world_model == enable_world_model
    assert sum(p.numel() for p in policy.parameters()) > 0


@extended_test
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None:
    """Policy loaded from hub produces finite losses with a correctly-shaped batch."""
    policy = VLAJEPAPolicy.from_pretrained(repo_id)
    policy.train()
    batch = _make_hub_train_batch(policy)
    loss, logs = policy.forward(batch)
    assert torch.isfinite(loss)
    assert "action_loss" in logs
    if enable_world_model:
        assert "wm_loss" in logs


@extended_test
def test_hub_freeze_qwen_disables_world_model() -> None:
    """freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model."""
    policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"])
    assert not policy.config.enable_world_model
    assert policy.model.video_predictor is None
    qwen_params = list(policy.model.qwen.parameters())
    assert all(not p.requires_grad for p in qwen_params)
    assert any(p.requires_grad for p in policy.model.action_model.parameters())


@extended_test
def test_hub_disable_world_model_loads_simpler_env() -> None:
    """SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference."""
    policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv")
    assert not policy.config.enable_world_model
    assert policy.model.video_predictor is None
    assert policy.model.video_encoder is None


@extended_test
def test_hub_libero_inference_shape() -> None:
    """select_action returns the expected shape using the LIBERO hub checkpoint."""
    policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO")
    policy.eval()
    batch = _make_hub_inference_batch(policy)
    action = policy.select_action(batch)
    assert action.shape[-1] == policy.config.action_dim
# ---------------------------------------------------------------------------
# Postprocessor unnormalization tests
#
# These tests verify that the postprocessor pipeline (clip → unnorm → binarize)
# correctly applies MIN_MAX unnormalization after predict_action_chunk.
# ---------------------------------------------------------------------------
def _make_dataset_stats(action_dim: int = ACTION_DIM) -> dict:
    """Returns sample dataset_stats with a simple [i, i+10] range per action dim."""
    from lerobot.utils.constants import ACTION

    return {
        ACTION: {
            "min": torch.tensor([float(i) for i in range(action_dim)], dtype=torch.float32),
            "max": torch.tensor([float(i) + 10.0 for i in range(action_dim)], dtype=torch.float32),
        }
    }


@torch.no_grad()
def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None) -> None:
    """UnnormalizerProcessorStep with MIN_MAX produces the correct inverse of MIN_MAX normalization."""
    from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
    from lerobot.processor import UnnormalizerProcessorStep
    from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
    from lerobot.utils.constants import ACTION

    dataset_stats = _make_dataset_stats()
    rng = np.random.default_rng(7)
    actions_np = rng.uniform(-1.0, 1.0, (2, ACTION_HORIZON, ACTION_DIM)).astype(np.float32)
    a_min = dataset_stats[ACTION]["min"].numpy()
    a_max = dataset_stats[ACTION]["max"].numpy()
    expected = (actions_np + 1.0) / 2.0 * (a_max - a_min) + a_min
    features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
    unnorm_step = UnnormalizerProcessorStep(
        features=features,
        norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
        stats=dataset_stats,
    )
    actions_tensor = torch.from_numpy(actions_np)
    transition = policy_action_to_transition(actions_tensor)
    result = transition_to_policy_action(unnorm_step(transition)).numpy()
    np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)


@torch.no_grad()
def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None:
    """ClipActionsProcessorStep clamps to [-1, 1] before unnormalization."""
    from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
    from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
    from lerobot.processor import UnnormalizerProcessorStep
    from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
    from lerobot.utils.constants import ACTION

    dataset_stats = _make_dataset_stats()
    a_min = dataset_stats[ACTION]["min"].numpy()
    a_max = dataset_stats[ACTION]["max"].numpy()
    # Deliberately out-of-range inputs
    actions_np = np.array([[[2.0] * ACTION_DIM, [-3.0] * ACTION_DIM]], dtype=np.float32)
    clipped = np.clip(actions_np, -1.0, 1.0)
    expected = (clipped + 1.0) / 2.0 * (a_max - a_min) + a_min
    features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
    clip_step = ClipActionsProcessorStep()
    unnorm_step = UnnormalizerProcessorStep(
        features=features,
        norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
        stats=dataset_stats,
    )
    transition = policy_action_to_transition(torch.from_numpy(actions_np))
    transition = clip_step(transition)
    result = transition_to_policy_action(unnorm_step(transition)).numpy()
    np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)


@torch.no_grad()
def test_postprocessor_applied_after_predict_action_chunk(
    patch_vla_jepa_external_models: None, monkeypatch: pytest.MonkeyPatch
) -> None:
    """predict_action_chunk returns raw actions; the postprocessor applies unnormalization.

    Verifies the split: predict_action_chunk returns normalized actions, and calling the
    postprocessor on them produces the correctly unnormalized result.
    """
    from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors

    raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
    cfg = make_config()
    cfg.clip_normalized_actions = False
    cfg.binarize_gripper_action = False
    policy = VLAJEPAPolicy(cfg)
    policy.eval()
    monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())
    dataset_stats = _make_dataset_stats()
    _, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
    batch = make_inference_batch()
    chunk = policy.predict_action_chunk(batch)
    # predict_action_chunk returns raw (normalized) actions
    assert torch.allclose(chunk, torch.zeros_like(chunk), atol=1e-6), (
        "predict_action_chunk should return raw actions without unnormalization applied."
    )
    # Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i
    unnormed = postprocessor(chunk)
    from lerobot.utils.constants import ACTION

    a_min = dataset_stats[ACTION]["min"].numpy()
    a_max = dataset_stats[ACTION]["max"].numpy()
    expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0]
    assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5)
# ---------------------------------------------------------------------------
# World-model view adjustment (padding / trimming) tests
# ---------------------------------------------------------------------------
_MULTIVIEW_NUM_FRAMES = 4  # must be >= 2 * jepa_tubelet_size (=2) for world-model tests


def _make_multiview_config(num_views: int, jepa_tubelet_size: int = 2) -> VLAJEPAConfig:
    from lerobot.configs.types import FeatureType, PolicyFeature
    from lerobot.utils.constants import OBS_IMAGES, OBS_STATE

    config = VLAJEPAConfig(
        input_features={
            **{
                f"{OBS_IMAGES}.cam{i}": PolicyFeature(
                    type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)
                )
                for i in range(num_views)
            },
            OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
        },
        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
        device="cpu",
        chunk_size=ACTION_HORIZON,
        n_action_steps=N_ACTION_STEPS,
        action_dim=ACTION_DIM,
        state_dim=STATE_DIM,
        num_video_frames=_MULTIVIEW_NUM_FRAMES,
        num_action_tokens_per_timestep=2,
        num_embodied_action_tokens_per_instruction=3,
        num_inference_timesteps=2,
        action_hidden_size=QWEN_HIDDEN_SIZE,
        action_model_type="DiT-test",
        action_num_layers=1,
        predictor_depth=1,
        predictor_num_heads=2,
        predictor_mlp_ratio=2.0,
        jepa_tubelet_size=jepa_tubelet_size,
    )
    config.validate_features()
    return config


def _make_multiview_train_batch(num_views: int, batch_size: int = BATCH_SIZE) -> dict:
    from lerobot.utils.constants import OBS_IMAGES, OBS_STATE

    batch = {
        f"{OBS_IMAGES}.cam{i}": torch.rand(batch_size, _MULTIVIEW_NUM_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE)
        for i in range(num_views)
    }
    batch[OBS_STATE] = torch.randn(batch_size, 1, STATE_DIM)
    batch[ACTION] = torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM)
    batch["task"] = ["pick up the cube"] * batch_size
    return batch


@pytest.mark.parametrize(
    "num_views",
    [
        1,  # fewer views than jepa_tubelet_size → first view duplicated
        2,  # exact match → unchanged
        3,  # more views than jepa_tubelet_size → trimmed to first two
    ],
)
def test_training_forward_world_model_view_adjustment(
    patch_vla_jepa_external_models: None,
    num_views: int,
) -> None:
    """World-model view padding/trimming must not break the training forward pass."""
    set_seed_all(42)
    policy = VLAJEPAPolicy(_make_multiview_config(num_views=num_views, jepa_tubelet_size=2))
    policy.train()
    loss, logs = policy.forward(_make_multiview_train_batch(num_views=num_views))
    assert torch.isfinite(loss)
    assert logs["wm_loss"] >= 0


def test_single_view_is_duplicated_for_world_model(patch_vla_jepa_external_models: None) -> None:
    """With one dataset view and jepa_tubelet_size=2, the view must be duplicated before encoding."""
    set_seed_all(42)
    policy = VLAJEPAPolicy(_make_multiview_config(num_views=1, jepa_tubelet_size=2))
    policy.train()
    captured_videos: list = []
    original_processor = policy.model.video_processor

    class _CapturingProcessor:
        def __call__(self, videos: list, return_tensors: str) -> dict:
captured_videos.extend(videos)
return original_processor(videos=videos, return_tensors=return_tensors)
policy.model.video_processor = _CapturingProcessor()
policy.forward(_make_multiview_train_batch(num_views=1))
# reshape is batch-major: (b0v0, b0v1, b1v0, b1v1, …)
assert len(captured_videos) == BATCH_SIZE * 2
for i in range(BATCH_SIZE):
np.testing.assert_array_equal(captured_videos[2 * i], captured_videos[2 * i + 1])
def test_excess_views_trimmed_for_world_model(patch_vla_jepa_external_models: None) -> None:
"""With three dataset views and jepa_tubelet_size=2, only the first two views reach the encoder."""
set_seed_all(42)
policy = VLAJEPAPolicy(_make_multiview_config(num_views=3, jepa_tubelet_size=2))
policy.train()
captured_videos: list = []
original_processor = policy.model.video_processor
class _CapturingProcessor:
def __call__(self, videos: list, return_tensors: str) -> dict:
captured_videos.extend(videos)
return original_processor(videos=videos, return_tensors=return_tensors)
policy.model.video_processor = _CapturingProcessor()
policy.forward(_make_multiview_train_batch(num_views=3))
# Only B*2 items must reach the encoder, not B*3.
assert len(captured_videos) == BATCH_SIZE * 2
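
Note on the view adjustment exercised by these tests: the parametrize comments describe duplicating the first view when there are fewer views than `jepa_tubelet_size`, and trimming to the first views when there are more. The NumPy sketch below illustrates that behavior under stated assumptions. The helper name `adjust_views` and the `(batch, views, ...)` layout are hypothetical, and the real policy code may pad in a different position.

```python
import numpy as np


def adjust_views(views: np.ndarray, target: int) -> np.ndarray:
    """Pad or trim the view axis (axis 1) of a (batch, views, ...) array.

    Extra views are dropped; missing views are filled with copies of view 0.
    """
    num_views = views.shape[1]
    if num_views >= target:
        return views[:, :target]
    padding = np.repeat(views[:, :1], target - num_views, axis=1)
    return np.concatenate([views, padding], axis=1)


# Two batch items, one view each, three features per view.
single = np.arange(6).reshape(2, 1, 3)
padded = adjust_views(single, target=2)

# Flattening batch and view axes gives batch-major order: b0v0, b0v1, b1v0, b1v1.
flat = padded.reshape(-1, 3)

# Three views per item are trimmed down to the first two.
triple = np.arange(18).reshape(2, 3, 3)
trimmed = adjust_views(triple, target=2)
```

With one input view, consecutive pairs in `flat` are identical, which is the property `test_single_view_is_duplicated_for_world_model` asserts on the captured videos.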

@@ -1,60 +0,0 @@
#!/usr/bin/env python
from __future__ import annotations
import pytest
import torch
from lerobot.policies.vla_jepa.world_model import (
ActionConditionedVideoPredictor,
)
_ACTION_EMBED_DIM = 8
def _make_predictor(
embed_dim: int = 8,
action_embed_dim: int = _ACTION_EMBED_DIM,
predictor_embed_dim: int = 24,
num_action_tokens: int = 2,
tokens_per_frame: int = 1,
) -> ActionConditionedVideoPredictor:
return ActionConditionedVideoPredictor(
num_frames=1,
img_size=(1, tokens_per_frame),
patch_size=1,
tubelet_size=1,
embed_dim=embed_dim,
action_embed_dim=action_embed_dim,
predictor_embed_dim=predictor_embed_dim,
depth=1,
num_heads=2,
mlp_ratio=2.0,
num_action_tokens_per_step=num_action_tokens,
)
@pytest.mark.parametrize(
"batch,num_steps,tokens_per_frame,embed_dim",
[
(1, 2, 1, 8),
(2, 3, 4, 8),
(4, 5, 2, 16),
],
)
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
predictor = _make_predictor(
embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame
)
frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim)
action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM)
out = predictor(frame_tokens, action_tokens)
assert tuple(out.shape) == (batch, num_steps * tokens_per_frame, embed_dim)
assert torch.isfinite(out).all()
def test_predictor_step_mismatch_raises() -> None:
predictor = _make_predictor(tokens_per_frame=4)
frame_tokens = torch.randn(2, 3 * 4, 8) # 3 steps, 4 tokens each
with pytest.raises(RuntimeError):
predictor(frame_tokens, torch.randn(2, 2 * 2, 8)) # 2 steps → mismatch
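
Note on the step bookkeeping behind the mismatch test: frame tokens have length `num_steps * tokens_per_frame` and action tokens have length `num_steps * num_action_tokens`, so the two sequences must imply the same step count. The sketch below makes that check explicit. The helper `infer_num_steps` is hypothetical and raises `ValueError`, whereas the test above expects the real predictor to raise `RuntimeError`.

```python
def infer_num_steps(
    frame_len: int,
    tokens_per_frame: int,
    action_len: int,
    action_tokens_per_step: int,
) -> int:
    """Return the shared step count, or raise if the two sequences disagree."""
    if frame_len % tokens_per_frame or action_len % action_tokens_per_step:
        raise ValueError("sequence length is not a whole number of steps")
    frame_steps = frame_len // tokens_per_frame
    action_steps = action_len // action_tokens_per_step
    if frame_steps != action_steps:
        raise ValueError(
            f"frame tokens imply {frame_steps} steps, action tokens imply {action_steps}"
        )
    return frame_steps


# Matching case: 3 steps of 4 frame tokens, 3 steps of 2 action tokens.
steps = infer_num_steps(frame_len=12, tokens_per_frame=4, action_len=6, action_tokens_per_step=2)

# Mismatching case from the test: 3 frame steps against 2 action steps.
try:
    infer_num_steps(frame_len=12, tokens_per_frame=4, action_len=4, action_tokens_per_step=2)
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```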

@@ -1,155 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare the PI0.5 processor pipeline against the vendored OpenPI reference processors."""
import os
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
from lerobot.policies.pi05 import PI05Policy # noqa: E402
from lerobot.policies.pi05.configuration_pi05 import PI05Config # noqa: E402
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
IMAGE_KEYS,
assert_processor_inputs_match_lerobot,
clone_batch,
make_openpi_observation_from_raw,
openpi_model_actions_from_raw,
)
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
)
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 200
DEVICE = torch.device("cpu")
DUMMY_DATASET_STATS = {
OBS_STATE: {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
"q01": torch.zeros(DUMMY_STATE_DIM),
"q99": torch.ones(DUMMY_STATE_DIM),
},
ACTION: {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
"q01": torch.zeros(DUMMY_ACTION_DIM),
"q99": torch.ones(DUMMY_ACTION_DIM),
},
"images": {
key: {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
"q01": torch.zeros(3, 224, 224),
"q99": torch.ones(3, 224, 224),
}
for key in IMAGE_KEYS
},
}
class PI05PolicyInputAdapter(torch.nn.Module):
"""Minimal adapter exposing PI0.5 policy image preparation without loading model weights."""
_preprocess_images = PI05Policy._preprocess_images
def __init__(self, config: PI05Config) -> None:
super().__init__()
self.config = config
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
def create_pi05_config() -> PI05Config:
config = PI05Config(device=str(DEVICE))
config.max_state_dim = DUMMY_STATE_DIM
config.max_action_dim = DUMMY_ACTION_DIM
config.chunk_size = DUMMY_ACTION_HORIZON
config.n_action_steps = DUMMY_ACTION_HORIZON
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
config.input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
**{
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
for key in IMAGE_KEYS
},
}
config.output_features = {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
}
return config
def create_dummy_data() -> dict:
batch_size = 2
prompt = "Pick up the red block and place it in the bin"
return {
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
ACTION: torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
),
**{
f"observation.images.{key}": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
)
for key in IMAGE_KEYS
},
"task": [prompt for _ in range(batch_size)],
}
def test_pi05_processor_inputs_match_openpi_reference():
torch.manual_seed(0)
config = create_pi05_config()
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
raw_batch = create_dummy_data()
lerobot_batch = preprocessor(clone_batch(raw_batch))
openpi_observation = make_openpi_observation_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
max_token_len=DUMMY_MAX_TOKEN_LEN,
dataset_stats=DUMMY_DATASET_STATS,
pi05=True,
)
assert_processor_inputs_match_lerobot(
PI05PolicyInputAdapter(config),
lerobot_batch,
openpi_observation,
compare_state=False,
)
torch.testing.assert_close(
lerobot_batch[ACTION],
openpi_model_actions_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
dataset_stats=DUMMY_DATASET_STATS,
pi05=True,
),
rtol=0,
atol=0,
)

@@ -1,156 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare the PI0 processor pipeline against the vendored OpenPI reference processors."""
import os
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
from lerobot.policies.pi0 import PI0Policy # noqa: E402
from lerobot.policies.pi0.configuration_pi0 import PI0Config # noqa: E402
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
IMAGE_KEYS,
assert_processor_inputs_match_lerobot,
clone_batch,
make_openpi_observation_from_raw,
openpi_model_actions_from_raw,
)
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
)
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 48
DEVICE = torch.device("cpu")
DUMMY_DATASET_STATS = {
OBS_STATE: {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
"q01": torch.zeros(DUMMY_STATE_DIM),
"q99": torch.ones(DUMMY_STATE_DIM),
},
ACTION: {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
"q01": torch.zeros(DUMMY_ACTION_DIM),
"q99": torch.ones(DUMMY_ACTION_DIM),
},
"images": {
key: {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
"q01": torch.zeros(3, 224, 224),
"q99": torch.ones(3, 224, 224),
}
for key in IMAGE_KEYS
},
}
class PI0PolicyInputAdapter(torch.nn.Module):
"""Minimal adapter exposing PI0 policy input-preparation helpers without loading model weights."""
_preprocess_images = PI0Policy._preprocess_images
prepare_state = PI0Policy.prepare_state
def __init__(self, config: PI0Config) -> None:
super().__init__()
self.config = config
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
def create_pi0_config() -> PI0Config:
config = PI0Config(device=str(DEVICE))
config.max_state_dim = DUMMY_STATE_DIM
config.max_action_dim = DUMMY_ACTION_DIM
config.chunk_size = DUMMY_ACTION_HORIZON
config.n_action_steps = DUMMY_ACTION_HORIZON
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
config.input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
**{
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
for key in IMAGE_KEYS
},
}
config.output_features = {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
}
return config
def create_dummy_data() -> dict:
batch_size = 2
prompt = "Pick up the red block and place it in the bin"
return {
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
ACTION: torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
),
**{
f"observation.images.{key}": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
)
for key in IMAGE_KEYS
},
"task": [prompt for _ in range(batch_size)],
}
def test_pi0_processor_inputs_match_openpi_reference():
torch.manual_seed(0)
config = create_pi0_config()
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
raw_batch = create_dummy_data()
lerobot_batch = preprocessor(clone_batch(raw_batch))
openpi_observation = make_openpi_observation_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
max_token_len=DUMMY_MAX_TOKEN_LEN,
dataset_stats=DUMMY_DATASET_STATS,
pi05=False,
)
assert_processor_inputs_match_lerobot(
PI0PolicyInputAdapter(config),
lerobot_batch,
openpi_observation,
compare_state=True,
)
torch.testing.assert_close(
lerobot_batch[ACTION],
openpi_model_actions_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
dataset_stats=DUMMY_DATASET_STATS,
pi05=False,
),
rtol=0,
atol=0,
)

@@ -1,14 +1,10 @@
"""Tests for policy.path support in YAML config files (issue #2957)."""
import json
import sys
import tempfile
from dataclasses import dataclass, field
from unittest.mock import patch
import yaml
from lerobot.configs import parser
from lerobot.configs.parser import (
_config_path_args,
_config_yaml_overrides,
@@ -20,8 +16,7 @@ from lerobot.configs.parser import (
def test_extract_path_fields_from_yaml():
"""Test that policy.path is extracted from a YAML config and the policy block
is removed entirely (siblings are captured separately as cli_overrides)."""
"""Test that policy.path is extracted from a YAML config and removed."""
config = {
"dataset": {"repo_id": "lerobot/pusht"},
"policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False},
@@ -31,33 +26,26 @@ def test_extract_path_fields_from_yaml():
config_path = f.name
_config_path_args.clear()
_config_yaml_overrides.clear()
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
# Path should be extracted and stored
assert _config_path_args["policy"] == "lerobot/smolvla_base"
# Cleaned config should not have the policy block at all -- draccus must not
# try to decode it as PreTrainedConfig; the actual config comes from
# from_pretrained(path) with the captured overrides applied on top.
# Cleaned config should not have the path field
with open(cleaned_path) as f:
cleaned = yaml.safe_load(f)
assert "policy" not in cleaned
assert "path" not in cleaned["policy"]
assert cleaned["policy"]["type"] == "smolvla"
assert cleaned["policy"]["push_to_hub"] is False
# Original dataset should be untouched
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
# Sibling overrides (excluding type/path) captured for from_pretrained.
overrides = get_yaml_overrides("policy")
assert any("push_to_hub=false" in o for o in overrides)
_config_path_args.clear()
_config_yaml_overrides.clear()
def test_extract_path_fields_from_json():
"""Test that policy.path is extracted from a JSON config and the policy
block is removed entirely."""
"""Test that policy.path is extracted from a JSON config."""
config = {
"policy": {"type": "act", "path": "some/local/path"},
}
@@ -66,17 +54,15 @@ def test_extract_path_fields_from_json():
config_path = f.name
_config_path_args.clear()
_config_yaml_overrides.clear()
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
assert _config_path_args["policy"] == "some/local/path"
with open(cleaned_path) as f:
cleaned = json.load(f)
assert "policy" not in cleaned
assert "path" not in cleaned["policy"]
_config_path_args.clear()
_config_yaml_overrides.clear()
def test_extract_no_path_returns_original():
@@ -230,91 +216,3 @@ def test_flatten_nested_with_bools():
args = _flatten_to_cli_args(d)
assert "--optimizer.use_warmup=true" in args
assert "--optimizer.lr=0.01" in args
def test_extract_removes_field_with_siblings_and_no_type():
"""Regression: when policy.path has siblings but no type:, the entire policy
block must still be removed from the cleaned config. Otherwise draccus tries
to decode the leftover dict as PreTrainedConfig and crashes on the missing
type discriminator.
"""
config = {
"dataset": {"repo_id": "lerobot/pusht"},
"policy": {
"path": "lerobot/smolvla_base",
"n_action_steps": 10,
"dtype": "bfloat16",
},
}
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml.dump(config, f)
config_path = f.name
_config_path_args.clear()
_config_yaml_overrides.clear()
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
with open(cleaned_path) as f:
cleaned = yaml.safe_load(f) or {}
assert "policy" not in cleaned, "policy block should be fully removed when path is present"
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
assert _config_path_args["policy"] == "lerobot/smolvla_base"
overrides = get_yaml_overrides("policy")
assert any("n_action_steps=10" in o for o in overrides)
assert any("dtype=bfloat16" in o for o in overrides)
_config_path_args.clear()
_config_yaml_overrides.clear()
@dataclass
class _DummyNested:
foo: int = 0
@dataclass
class _DummyConfig:
nested: _DummyNested = field(default_factory=_DummyNested)
other: str = "default"
@classmethod
def __get_path_fields__(cls):
return ["nested"]
def test_wrap_uses_cleaned_config_for_draccus_parse():
"""Regression: wrap() updates config_path_cli to point at the cleaned temp
file but must propagate that to the draccus.parse fallback branch. Without
the fix, cli_args still contains --config_path=<original> and draccus reads
the original YAML with `path:` still in it, crashing on the unknown field.
"""
config = {
"nested": {"path": "some/checkpoint", "foo": 42},
"other": "set-via-yaml",
}
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml.dump(config, f)
config_path = f.name
_config_path_args.clear()
_config_yaml_overrides.clear()
captured: dict = {}
@parser.wrap()
def main(cfg: _DummyConfig) -> _DummyConfig:
captured["cfg"] = cfg
return cfg
with patch.object(sys, "argv", ["prog", f"--config_path={config_path}"]):
main()
assert captured["cfg"].other == "set-via-yaml"
assert _config_path_args["nested"] == "some/checkpoint"
# Cleaned config dropped `nested:` entirely; defaults stand for this wrapper
# class (a real PreTrainedConfig would now load the checkpoint and apply
# the captured yaml_overrides via from_pretrained()).
assert captured["cfg"].nested.foo == 0
_config_path_args.clear()
_config_yaml_overrides.clear()
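
Note on the override flattening these tests rely on: `test_flatten_nested_with_bools` expects a nested dict to become dotted `--key=value` arguments with booleans lowercased (`--optimizer.use_warmup=true`). A minimal sketch of such a flattening follows. The function `flatten_to_cli_args_sketch` is a hypothetical stand-in, not the parser's `_flatten_to_cli_args` implementation.

```python
def flatten_to_cli_args_sketch(config: dict, prefix: str = "") -> list[str]:
    """Turn a nested dict into dotted ``--a.b=value`` CLI arguments."""
    args: list[str] = []
    for key, value in config.items():
        dotted = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Recurse so nested keys are joined with dots.
            args.extend(flatten_to_cli_args_sketch(value, dotted))
        elif isinstance(value, bool):
            # Booleans are lowercased to match the form the tests assert on.
            args.append(f"--{dotted}={str(value).lower()}")
        else:
            args.append(f"--{dotted}={value}")
    return args


args = flatten_to_cli_args_sketch({"optimizer": {"use_warmup": True, "lr": 0.01}})
policy_args = flatten_to_cli_args_sketch({"push_to_hub": False, "n_action_steps": 10}, prefix="policy")
```

The `bool` check comes before the generic branch because `True` would otherwise be rendered as `True` rather than `true`.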

uv.lock (generated, 42 changes)

@@ -2915,11 +2915,6 @@ metaworld = [
{ name = "scipy" },
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
]
molmoact2 = [
{ name = "peft" },
{ name = "scipy" },
{ name = "transformers" },
]
motorbridge-dep = [
{ name = "motorbridge" },
]
@@ -3047,11 +3042,6 @@ video-benchmark = [
viz = [
{ name = "rerun-sdk" },
]
vla-jepa = [
{ name = "diffusers" },
{ name = "qwen-vl-utils" },
{ name = "transformers" },
]
wallx = [
{ name = "peft" },
{ name = "qwen-vl-utils" },
@@ -3120,7 +3110,6 @@ requires-dist = [
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
@@ -3142,7 +3131,6 @@ requires-dist = [
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'sarm'" },
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'unitree-g1'" },
{ name = "lerobot", extras = ["metaworld"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["molmoact2"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["motorbridge-dep"], marker = "extra == 'rebot'" },
{ name = "lerobot", extras = ["motorbridge-smart-servo-dep"], marker = "extra == 'rebot'" },
{ name = "lerobot", extras = ["multi-task-dit"], marker = "extra == 'all'" },
@@ -3150,7 +3138,6 @@ requires-dist = [
{ name = "lerobot", extras = ["openarms"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["peft"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'molmoact2'" },
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'peft'" },
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["phone"], marker = "extra == 'all'" },
@@ -3170,7 +3157,6 @@ requires-dist = [
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'unitree-g1'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'vla-jepa'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["rebot"], marker = "extra == 'all'" },
@@ -3179,7 +3165,6 @@ requires-dist = [
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'aloha'" },
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'libero'" },
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'metaworld'" },
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'molmoact2'" },
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'phone'" },
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'pi'" },
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" },
@@ -3191,21 +3176,18 @@ requires-dist = [
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'molmoact2'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'multi-task-dit'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'peft'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'pi'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'sarm'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'smolvla'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'topreward'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'vla-jepa'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'xvla'" },
{ name = "lerobot", extras = ["video-benchmark"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["viz"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["viz"], marker = "extra == 'core-scripts'" },
{ name = "lerobot", extras = ["viz"], marker = "extra == 'dataset-viz'" },
{ name = "lerobot", extras = ["vla-jepa"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["wallx"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["xvla"], marker = "extra == 'all'" },
{ name = "matplotlib", marker = "extra == 'matplotlib-dep'", specifier = ">=3.10.3,<4.0.0" },
@@ -3226,7 +3208,7 @@ requires-dist = [
{ name = "pandas", marker = "extra == 'video-benchmark'", specifier = ">=2.2.2,<2.4.0" },
{ name = "peft", marker = "extra == 'peft-dep'", specifier = ">=0.18.0,<1.0.0" },
{ name = "pillow", specifier = ">=10.0.0,<13.0.0" },
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.16" },
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.17" },
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" },
{ name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" },
{ name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" },
@@ -3267,7 +3249,7 @@ requires-dist = [
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "topreward", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
[[package]]
name = "librt"
@@ -4615,7 +4597,7 @@ wheels = [
[[package]]
name = "placo"
version = "0.9.15"
version = "0.9.16"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cmeel" },
@@ -4625,16 +4607,16 @@ dependencies = [
{ name = "pin" },
{ name = "rhoban-cmeel-jsoncpp" },
]
sdist = { url = "https://files.pythonhosted.org/packages/40/c4/a33a0ee2ad798471a1c43a96109d28f358fd95c78a56f8cff57acb66d2bc/placo-0.9.15.tar.gz", hash = "sha256:df47f1154bae305c943bd20ba4f56d50ffc65625efc98679fefb11e8ff3c462c", size = 136856, upload-time = "2025-11-03T10:49:13.151Z" }
sdist = { url = "https://files.pythonhosted.org/packages/9e/0a/36c5b729d0d69075e7dfafd1b36c4df6fbb8c1ff1585e88d3c56d4c15010/placo-0.9.16.tar.gz", hash = "sha256:5314faaf6442e7ffe17347680d236af953951813bbfb1c09c4a27f7388d332e4", size = 136871, upload-time = "2025-11-07T14:24:58.811Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ef/03/207b1c087996b918fdbaa5a3a685e3b14b068cd303bf87affdf83f722b33/placo-0.9.15-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:eab7a299e73291fe631c02448b9e9826539f4824e198bcf85f7c91fdd77d054b", size = 1641975, upload-time = "2025-11-03T10:48:48.887Z" },
{ url = "https://files.pythonhosted.org/packages/92/55/40432b26bb1c5b9e677fbc41e8d85b54fa8897b7daebb2a22d410b0a7f7b/placo-0.9.15-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:23f9dd19b8d15fa9d86968948b57981ebc6f1decafeffc2d646d8b56f685b50d", size = 1515448, upload-time = "2025-11-03T10:48:50.562Z" },
{ url = "https://files.pythonhosted.org/packages/fd/8e/e6283201d329409dccf2045b5c1efd73b3dad5268143bbea4668029ca9c6/placo-0.9.15-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:2680a2166c23a0a2aa6226ad75c63a2b2310c812673a5db296616d9af053e076", size = 2106550, upload-time = "2025-11-03T10:48:52.364Z" },
{ url = "https://files.pythonhosted.org/packages/51/c3/77efe4c999e1d80ec14879ef73ea2a2144aa12db2b67870a562f87ed5b43/placo-0.9.15-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:1a2202a78bcd2874ca09a9a6526a95b38874803923cb9b3b4b96cd68ab4b7217", size = 2178531, upload-time = "2025-11-03T10:48:53.932Z" },
{ url = "https://files.pythonhosted.org/packages/fe/e7/b5cc5ad53414ff7af3357e0c9d97d902a3ce276e7810f8814fe9f0c1fb70/placo-0.9.15-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:84a445a99b059a512d1b4c64841a91d6f50149c7be9255c65bedeebbe6663989", size = 1641982, upload-time = "2025-11-03T10:48:55.277Z" },
{ url = "https://files.pythonhosted.org/packages/ad/1c/1c9163d941698a077617f218041efc573d3bf5a1c169a284112bd622fccd/placo-0.9.15-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b3106e7e6b05cbfa494239d8aa14795f7da8ee5dec851602f0d6297e311d7334", size = 1515447, upload-time = "2025-11-03T10:48:56.975Z" },
{ url = "https://files.pythonhosted.org/packages/cd/22/3d9b9045b89248c8476dd42243bc9821a123d9199e4e96a944124ad80cf1/placo-0.9.15-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:66c3d099e87551401aace04f1293a3c3563b1399319976647846845bf92c3ccf", size = 2106558, upload-time = "2025-11-03T10:48:58.667Z" },
{ url = "https://files.pythonhosted.org/packages/20/0b/45dbdd2c378a7cece578b7344fda493d5a2aa6777089798a315ce4f97c22/placo-0.9.15-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0e06b7d3d618ddc2b649ab8b0b46db8001fe72fe2fbcc801524df0ccc8a3da40", size = 2178531, upload-time = "2025-11-03T10:49:00.533Z" },
{ url = "https://files.pythonhosted.org/packages/a4/95/8a85b58033303fd354a680e1494f47801abdca9133c222ae1c2473983f25/placo-0.9.16-0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:417a89920b340e3aec19f1f49e1fb06789c679a807450157af8bdf4aef4bc82b", size = 1641806, upload-time = "2025-11-07T14:24:34.736Z" },
{ url = "https://files.pythonhosted.org/packages/92/bd/2fb3556c71b0689b3168c0e85fce5befb605affcfe4afb3b5e7b5ba6749f/placo-0.9.16-0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a7ef7ac33ba889d2122db0d7ed55eeecdffed020e2282712989bb11e408bab40", size = 1515468, upload-time = "2025-11-07T14:24:36.587Z" },
{ url = "https://files.pythonhosted.org/packages/ea/fd/7dba380720dfb89df582a51d0b2cb43957a36849f676baa3dfc74704e67f/placo-0.9.16-0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:885773fe8a8e809022451ec16d47479562a042596f663b8c5bbe762cd616f573", size = 2106540, upload-time = "2025-11-07T14:24:38.149Z" },
{ url = "https://files.pythonhosted.org/packages/7a/40/97c7c799fe4f89111b973d7a5f86626a2ec1d0e6e20ce2988e0a2bda66f5/placo-0.9.16-0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:19f097305c714e539fbf19e761897f6daab2ff73f639319431b144e77dd3852e", size = 2178511, upload-time = "2025-11-07T14:24:40.04Z" },
{ url = "https://files.pythonhosted.org/packages/f7/4d/f1700aae269584477b5d72561d2fc5ace37b4bca167892a74a369849c67e/placo-0.9.16-0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:be11fa987702114097ccf3d94e1c4a891796878429e25c8d88b187ecc652e7ae", size = 1641812, upload-time = "2025-11-07T14:24:41.308Z" },
{ url = "https://files.pythonhosted.org/packages/43/d7/21d1d0dd1311c0cbd9ccd233cdae520bbe2370095e3c831059d6077c90bd/placo-0.9.16-0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c2d65aeb4844eae28006ad3a50c8519b27c701912cc99c46c95e33ed049f3635", size = 1515457, upload-time = "2025-11-07T14:24:42.758Z" },
{ url = "https://files.pythonhosted.org/packages/0f/e8/939ba23bfa539fb90ab9ab1c2c59ff9a9a46e24699fc90e8ca3ff2948646/placo-0.9.16-0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a7633aff1c592c1f45e86a174a372d5d7972673935cb9151391277ff49ec2072", size = 2106538, upload-time = "2025-11-07T14:24:44.517Z" },
{ url = "https://files.pythonhosted.org/packages/08/00/ad24cc0ad85fbe12267df28c2061e1eaef8f852146c467fcd7a681e11028/placo-0.9.16-0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0d97a7284b65fc45aef27865c80cf7e53f04646d35bb18494ab62dfbbc9a35bd", size = 2178514, upload-time = "2025-11-07T14:24:45.994Z" },
]
[[package]]