Compare commits

..

4 Commits

Author SHA1 Message Date
Khalil Meftah
015c88cf0d Frame count is now derived from the upstream .npy length 2026-05-18 10:57:16 +02:00
Khalil Meftah
0164725af8 fix decord 2026-05-18 10:39:51 +02:00
Khalil Meftah
34274c6f70 scripts: add Robometer parity checks (upstream example videos + LIBERO) 2026-05-17 15:41:31 +02:00
Khalil Meftah
f6a13b1338 Add Robometer reward model 2026-05-17 14:59:23 +02:00
191 changed files with 4980 additions and 28240 deletions

View File

@@ -3,8 +3,6 @@
title: LeRobot
- local: installation
title: Installation
- local: cheat-sheet
title: Cheat sheet
title: Get started
- sections:
- local: il_robots
@@ -39,12 +37,8 @@
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
- local: language_and_recipes
title: Language Columns and Recipes
- local: tools
title: Tools
- local: video_encoding_parameters
title: Video encoding parameters
- local: dataset_subtask
title: Using Subtasks in the Dataset
- local: streaming_video_encoding
title: Streaming Video Encoding
title: "Datasets"
@@ -59,10 +53,6 @@
title: π₀-FAST (Pi0Fast)
- local: pi05
title: π₀.₅ (Pi05)
- local: molmoact2
title: MolmoAct2
- local: vla_jepa
title: VLA-JEPA
- local: eo1
title: EO-1
- local: groot
@@ -77,8 +67,6 @@
- sections:
- local: sarm
title: SARM
- local: topreward
title: TOPReward
title: "Reward Models"
- sections:
- local: inference
@@ -151,8 +139,6 @@
title: OMX
- local: openarm
title: OpenArm
- local: rebot_b601
title: reBot B601-DM
title: "Robots"
- sections:
- local: phone_teleop

View File

@@ -79,13 +79,17 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/act_policy \
--robot.type=so101_follower \
lerobot-record \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=my_robot \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--display_data=true \
--task="Your task description" \ # can be skipped for ACT
--duration=60
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
--dataset.num_episodes=10 \
--dataset.single_task="Your task description" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--policy.path=${HF_USER}/act_policy
```

View File

@@ -1,139 +0,0 @@
# Cheat sheet
All of the LeRobot commands in one place. If you forgot how to use a specific command or want to learn about a new one you can do it here.
> [!WARNING]
> For all of the commands listed below remember to change the ports/names/ids to your own values!
> [!TIP]
> Another great way to look at all the commands and get them configured for your specific setup is to use this [Jupyter Notebook](https://github.com/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb).
### Setup and installation
For installation please look at [LeRobot Installation](https://huggingface.co/docs/lerobot/main/en/installation).
### Useful tools
###### Find port
Use this to identify which serial ports your robots are connected to. Follow the instructions in your terminal: you will be asked to unplug the USB cable and press Enter. The script will then detect and print the correct serial port for that robot.
```bash
lerobot-find-port
```
###### Find cameras
Quickly find camera indices and verify their output. This command prints camera information to the terminal and saves test frames from each detected camera to `lerobot/outputs/captured_images`
```bash
lerobot-find-cameras
```
### Calibration
In most cases you will need to perform calibration just once for each robot and teleoperation device. Before performing the calibration make sure that all the joints are roughly in the middle position.
```bash
lerobot-calibrate \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=my_follower_arm
```
Make sure that you use the same IDs used during calibration later for the other scripts. That's how LeRobot finds the calibration files.
### Teleoperation
Teleoperating with two cameras and displaying the data with Rerun.
```bash
lerobot-teleoperate \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=my_follower_arm \
--robot.cameras="{ top: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
--teleop.type=so101_leader \
--teleop.port=/dev/ttyACM1 \
--teleop.id=my_leader_arm \
--display_data=true
```
### Recording a dataset
The dataset is automatically uploaded to the server and saved under repo_id, make sure you are logged in to your HF account with CLI:
`hf auth login`
You can get the token from: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
```bash
lerobot-record \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=my_follower_arm \
--robot.cameras="{ top: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
--teleop.type=so101_leader \
--teleop.port=/dev/ttyACM1 \
--teleop.id=my_leader_arm \
--dataset.repo_id=${HF_USER}/so101_dataset_test \
--dataset.num_episodes=30 \
--dataset.single_task="put the red brick in a bowl" \
--dataset.streaming_encoding=true \
--display_data=true
```
While collecting the dataset you can control the process with your keyboard:
Control the data recording flow using keyboard shortcuts:
- Press **Right Arrow (`→`)**: Save episode and move to the next.
- Press **Left Arrow (`←`)**: Delete current episode and retry.
- Press **Escape (`ESC`)**: Stop, encode videos, and upload.
### Training
Depending on your hardware training the policy might take a few hours. That's how you train simple `ACT` policy:
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/so101_dataset_test \
--policy.type=act \
--output_dir=outputs/train/act_so101_test \
--job_name=act_so101_test \
--policy.device=cuda \
--wandb.enable=true \
--policy.repo_id=${HF_USER}/policy_test \
--steps=20000
```
- Policy Types: `act`, `diffusion`, `smolvla`, `pi05`
- Devices: `cuda` (NVIDIA), `mps` (Apple Silicon), `cpu`
If you want to fine-tune a specific model you can provide the path to the model. In this case path is enough and type can be skipped.
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/so101_dataset_test \
--policy.path=username/the_policy_to_finetune \
--policy.device=cuda \
--policy.repo_id=${HF_USER}/policy_test \
--output_dir=outputs/train/act_so101_test \
--steps=20000
```
### Inference
Inference means running the trained policy/model on a robot. For that we use `lerobot-rollout`. You will need to provide a path to your policy. It can be a local path or a path to Hugging Face for example "lerobot/folding_latest". Your cameras configuration needs to match what was used when collecting the dataset. Duration is in seconds if unspecified, it will run forever.
> [!TIP]
> If you are using the previous release V0.5.1 instead of `lerobot-rollout` you need to use `lerobot-record`. More information [here](https://huggingface.co/docs/lerobot/v0.5.1/en/il_robots#run-inference-and-evaluate-your-policy).
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/my_policy \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM1 \
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video1, width: 640, height: 480, fps: 30}, side: {type: opencv, index_or_path: /dev/video5, width: 640, height: 480, fps: 30}}" \
--task="Put lego brick into the transparent box" \
--duration=60
```

View File

@@ -0,0 +1,277 @@
# Using Subtasks in LeRobot Datasets
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
## What are Subtasks?
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
1. "Approach the apple"
2. "Grasp the apple"
3. "Lift the apple"
4. "Move to basket"
5. "Release the apple"
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
width="80%"
/>
<p>
<em>Figure: Overview of subtask annotation.</em>
</p>
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
## Dataset Structure
Subtask information is stored in the dataset metadata:
```
my-dataset/
├── data/
│ └── ...
├── meta/
│ ├── info.json
│ ├── stats.json
│ ├── tasks.parquet
│ ├── subtasks.parquet # Subtask index → subtask string mapping
│ └── episodes/
│ └── ...
└── videos/
└── ...
```
### Subtasks Parquet File
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
| subtask_index | subtask (index column) |
| ------------- | ---------------------- |
| 0 | "Approach the apple" |
| 1 | "Grasp the apple" |
| 2 | "Lift the apple" |
| ... | ... |
### Frame-Level Annotations
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
```python
# Example frame data in the parquet file
{
"index": 42,
"timestamp": 1.4,
"episode_index": 0,
"task_index": 0,
"subtask_index": 2, # References "Lift the apple"
"observation.state": [...],
"action": [...],
}
```
## Annotating Datasets with Subtasks
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
After completing your annotation:
1. Click "Push to Hub" to upload your annotated dataset
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
## Loading Datasets with Subtasks
When you load a dataset with subtask annotations, the subtask information is automatically available:
```python
from lerobot.datasets import LeRobotDataset
# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Access a sample
sample = dataset[100]
# The sample includes both task and subtask information
print(sample["task"]) # "Collect the fruit"
print(sample["subtask"]) # "Grasp the apple"
print(sample["task_index"]) # tensor(0)
print(sample["subtask_index"]) # tensor(2)
```
### Checking for Subtask Support
You can check if a dataset has subtask annotations:
```python
# Check if subtasks are available
has_subtasks = (
"subtask_index" in dataset.features
and dataset.meta.subtasks is not None
)
if has_subtasks:
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
print("Subtasks:", list(dataset.meta.subtasks.index))
```
## Using Subtasks for Training
### With the Tokenizer Processor
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
```python
from lerobot.processor import TokenizerProcessorStep
# Create a tokenizer processor step
tokenizer_processor = TokenizerProcessorStep(
tokenizer_name_or_path="google/paligemma-3b-pt-224",
padding="max_length",
max_length=64,
)
# The processor will automatically tokenize subtasks if present in the batch
# and add them to the observation under:
# - "observation.subtask.tokens"
# - "observation.subtask.attention_mask"
```
When subtasks are available in the batch, the tokenizer processor adds:
- `observation.subtask.tokens`: Tokenized subtask text
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
### DataLoader with Subtasks
```python
import torch
from lerobot.datasets import LeRobotDataset
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=16,
shuffle=True,
)
for batch in dataloader:
# Access subtask information in the batch
subtasks = batch["subtask"] # List of subtask strings
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
# Use for training hierarchical policies or reward models
print(f"Batch subtasks: {set(subtasks)}")
```
## Example Datasets with Subtask Annotations
Try loading a dataset with subtask annotations:
```python
from lerobot.datasets import LeRobotDataset
# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Explore the subtasks
print("Available subtasks:")
for subtask_name in dataset.meta.subtasks.index:
print(f" - {subtask_name}")
# Get subtask distribution
subtask_counts = {}
for i in range(len(dataset)):
sample = dataset[i]
subtask = sample["subtask"]
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
print("\nSubtask distribution:")
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
print(f" {subtask}: {count} frames")
```
## Use Cases
### 1. Hierarchical Policy Training
Train policies that predict both actions and current subtask:
```python
class HierarchicalPolicy(nn.Module):
def __init__(self, num_subtasks):
super().__init__()
self.action_head = nn.Linear(hidden_dim, action_dim)
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
def forward(self, observations):
features = self.encoder(observations)
actions = self.action_head(features)
subtask_logits = self.subtask_head(features)
return actions, subtask_logits
```
### 2. Stage-Aware Reward Modeling (SARM)
Build reward models that understand task progression:
```python
# SARM predicts:
# - Stage: Which subtask is being executed (discrete)
# - Progress: How far along the subtask (continuous 0-1)
class SARMRewardModel(nn.Module):
def forward(self, observations):
features = self.encoder(observations)
stage_logits = self.stage_classifier(features)
progress = self.progress_regressor(features)
return stage_logits, progress
```
### 3. Progress Visualization
Monitor robot execution by tracking subtask progression:
```python
def visualize_execution(model, observations):
for t, obs in enumerate(observations):
action, subtask_logits = model(obs)
predicted_subtask = subtask_names[subtask_logits.argmax()]
print(f"t={t}: Executing '{predicted_subtask}'")
```
## API Reference
### LeRobotDataset Properties
| Property | Type | Description |
| --------------------------- | ---------------------- | ------------------------------------------ |
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
### Sample Keys
When subtasks are available, each sample includes:
| Key | Type | Description |
| --------------- | -------------- | ------------------------------------ |
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
| `subtask` | `str` | Natural language subtask description |
## Related Resources
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation

View File

@@ -194,7 +194,7 @@ lerobot-record \
--dataset.single_task="Navigate around obstacles" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
# --dataset.vcodec=auto \
--display_data=true
```

View File

@@ -105,12 +105,10 @@ These results demonstrate GR00T's strong generalization capabilities across dive
### Evaluate in your hardware setup
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
```bash
lerobot-rollout\
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
lerobot-record \
--robot.type=bi_so_follower \
--robot.left_arm_port=/dev/ttyACM1 \
--robot.right_arm_port=/dev/ttyACM0 \
@@ -121,12 +119,14 @@ lerobot-rollout\
}' \
--display_data=true \
--dataset.repo_id=<user>/eval_groot-bimanual \
--dataset.num_episodes=10 \
--dataset.single_task="Grab and handover the red cube to the other arm" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
# --dataset.vcodec=auto \
--policy.path=<user>/groot-bimanual \ # your trained model
--duration=600
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10
```
## License

View File

@@ -232,7 +232,7 @@ lerobot-record \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
# --dataset.vcodec=auto \
--display_data=true
```
@@ -278,6 +278,6 @@ lerobot-record \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
# --dataset.vcodec=auto \
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
```

View File

@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
port="/dev/tty.usbmodem58760431541",
id="my_red_robot_arm",
)
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
port="/dev/tty.usbmodem58760431551",
id="my_blue_leader_arm",
)
robot = SO101Follower(robot_config)
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
<hfoption id="Command">
```bash
lerobot-teleoperate \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem5AB90687491 \
--robot.id=my_follower_arm \
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so101_leader \
--teleop.port=/dev/tty.usbmodem5AB90689011 \
--teleop.id=my_leader_arm \
--robot.type=koch_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=my_awesome_follower_arm \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
--teleop.type=koch_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=my_awesome_leader_arm \
--display_data=true
```
</hfoption>
@@ -122,48 +122,34 @@ lerobot-teleoperate \
<!-- prettier-ignore-start -->
```python
import time
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
cameras={
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
}
camera_config = {
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
}
robot_config = KochFollowerConfig(
port="/dev/tty.usbmodem585A0076841",
id="my_red_robot_arm",
cameras=camera_config
)
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
teleop_config = KochLeaderConfig(
port="/dev/tty.usbmodem58760431551",
id="my_blue_leader_arm",
)
init_rerun(session_name="teleoperation")
robot = SO101Follower(robot_config)
teleop_device = SO101Leader(teleop_config)
robot = KochFollower(robot_config)
teleop_device = KochLeader(teleop_config)
robot.connect()
teleop_device.connect()
TARGET_HZ = 30
TIME_PER_FRAME = 1.0 / TARGET_HZ
while True:
start_time = time.perf_counter()
observation = robot.get_observation()
action = teleop_device.get_action()
robot.send_action(action)
log_rerun_data(observation=observation, action=action)
elapsed_time = time.perf_counter() - start_time
sleep_time = TIME_PER_FRAME - elapsed_time
if sleep_time > 0:
time.sleep(sleep_time)
```
<!-- prettier-ignore-end -->
@@ -207,7 +193,7 @@ lerobot-record \
--dataset.num_episodes=5 \
--dataset.single_task="Grab the black cube" \
--dataset.streaming_encoding=true \
# --dataset.camera_encoder.vcodec=auto \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
```
</hfoption>
@@ -216,11 +202,10 @@ lerobot-record \
<!-- prettier-ignore-start -->
```python
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets import LeRobotDataset
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
@@ -233,56 +218,71 @@ EPISODE_TIME_SEC = 60
RESET_TIME_SEC = 10
TASK_DESCRIPTION = "My task description"
def main():
# Create robot configuration
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
cameras={
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
}
)
# Create robot configuration
robot_config = SO100FollowerConfig(
id="my_awesome_follower_arm",
cameras={
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
},
port="/dev/tty.usbmodem58760434471",
)
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
)
teleop_config = SO100LeaderConfig(
id="my_awesome_leader_arm",
port="/dev/tty.usbmodem585A0077581",
)
# Initialize the robot and teleoperator
robot = SO101Follower(robot_config)
teleop = SO101Leader(teleop_config)
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
teleop = SO100Leader(teleop_config)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id="<hf_username>/<dataset_repo_id>",
# Create the dataset
dataset = LeRobotDataset.create(
repo_id="<hf_username>/<dataset_repo_id>",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
teleop.connect()
# Create the required processors
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
teleop.connect()
# Create the required processors
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
@@ -291,50 +291,26 @@ def main():
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
episode_idx += 1
dataset.save_episode()
episode_idx += 1
# finalize dataset
log_say("Finalizing dataset...")
dataset.finalize()
# Clean up
log_say("Stop recording")
robot.disconnect()
teleop.disconnect()
dataset.push_to_hub()
if __name__ == "__main__":
main()
# Clean up
log_say("Stop recording")
robot.disconnect()
teleop.disconnect()
dataset.push_to_hub()
```
<!-- prettier-ignore-end -->
@@ -372,7 +348,7 @@ The `record` function provides a suite of tools for capturing and managing data
##### 2. Checkpointing and Resuming
- Checkpoints are automatically created during recording.
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
- To start recording from scratch, **manually delete** the dataset directory.
##### 3. Recording Parameters
@@ -446,7 +422,7 @@ from lerobot.utils.utils import log_say
episode_idx = 0
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
robot = SO100Follower(robot_config)
robot.connect()
@@ -514,83 +490,6 @@ Additionally you can provide extra `tags` or specify a `license` for your model
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
#### Train using Hugging Face Jobs
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
To run the training use this command:
<hfoptions id="train_with_hf_jobs">
<hfoption id="Command">
```bash
hf jobs run \
--flavor a10g-small \
--timeout 4h \
--secrets HF_TOKEN \
huggingface/lerobot-gpu:latest \
-- \
python -m lerobot.scripts.lerobot_train \
--dataset.repo_id=username/dataset \
--policy.type=act \
--steps=5000 \
--batch_size=16 \
--policy.device=cuda \
--policy.repo_id=username/your_policy \
--log_freq=100
```
</hfoption>
<hfoption id="API example">
<!-- prettier-ignore-start -->
```python
from huggingface_hub import run_job, get_token
run_name = "act_so101_hf_jobs"
dataset_id = "username/dataset"
user_hub_id = "username"
command_args = [
"python", "-m", "lerobot.scripts.lerobot_train",
"--dataset.repo_id", dataset_id,
"--policy.type", "act",
"--steps", "5000",
"--batch_size", "16",
"--num_workers", "4",
"--policy.device", "cuda",
"--log_freq", "100",
"--save_freq", "1000",
"--save_checkpoint", "true",
"--wandb.enable", "false",
"--policy.repo_id", f"{user_hub_id}/{run_name}"
]
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
job_info = run_job(
image="huggingface/lerobot-gpu:latest",
command=command_args,
flavor="a10g-small",
timeout="4h",
secrets={"HF_TOKEN": get_token()}
)
print("\n🚀 Job successfully launched!")
print(f"🔹 Job ID: {job_info.id}")
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
```
<!-- prettier-ignore-end -->
</hfoption>
</hfoptions>
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
For longer training sessions increase the timeout.
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
#### Upload policy checkpoints
Once training is done, upload the latest checkpoint with:

View File

@@ -1,147 +0,0 @@
# Language columns and recipes
Most LeRobot datasets ship with a single `task` string per episode — fine for
short, single-instruction skills, but not enough for the longer-horizon,
multi-modal robot policies the field is moving toward (high-level planning,
memory, interjections, VQA, tool use). To support those policies without
forking the dataset format, LeRobot extends `LeRobotDataset` with two optional
language columns and a small recipe layer that turns those rows into
chat-style training samples on the fly.
The design splits cleanly into three layers:
1. **Data in the dataset** — language annotations stored next to frames in
`data/chunk-*/file-*.parquet` as two optional columns (`language_persistent`
and `language_events`). Datasets without these columns keep their existing
behavior.
2. **Recipe** — a YAML file that declares which annotation rows to bind and
how to lay them out as chat turns (`role`, `content`, optional images,
optional tool calls). Recipes are pure config; no Python required to add a
new one.
3. **Training format** — at sample time, `RenderMessagesStep` resolves the
recipe against the per-frame annotations and emits HF-style `messages` plus
LeRobot-specific sidecars (`message_streams`, `target_message_indices`)
that policy processors consume.
This page describes each layer in turn.
## Layer 1 — language columns in the dataset
The two optional columns live next to frame data in
`data/chunk-*/file-*.parquet`:
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
Both columns share the same row shape (event rows omit `timestamp` because the
frame the row sits on already provides it):
```text
role: string
content: string | null
style: string | null
timestamp: float32 # persistent rows only
camera: string | null # observation.images.* feature key, view-dependent rows only
tool_calls: list[Json] | null
```
The `camera` field tags rows whose `content` is grounded in a specific camera
view. Rows of view-dependent styles (`vqa` and `trace`) MUST set `camera` to
the matching `observation.images.*` feature key. Rows of every other style —
including `motion`, which describes robot-frame primitives in joint / Cartesian
terms — MUST leave `camera` as `null`. Pipeline writers and the validator
enforce this via `validate_camera_field(style, camera)`.
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
### Architecture
The language stack itself has three internal modules backing layer 1:
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
2. `lerobot.datasets.language_render` resolves rows and renders messages.
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
## Layer 2 — recipe anatomy
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
declare which annotation rows to pull (via `bindings`) and how to compose them
into chat turns (`messages`).
```yaml
messages:
- { role: user, content: "${task}", stream: high_level }
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
```
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
time, exactly one branch is selected deterministically from the sample index,
so different frames train different objectives (e.g. memory updates vs.
low-level execution vs. VQA) without any Python wiring.
### Temporal semantics
Persistent styles are active after emission until replaced:
- `active_at(t, style=subtask)`
- `nth_prev(style=memory, offset=1)`
- `nth_next(style=subtask, offset=1)`
Event styles only exist on their exact timestamp:
- `emitted_at(t, style=interjection)`
- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
- `emitted_at(t, role=assistant, tool_name=say)`
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
### View-dependent resolution
For view-dependent styles (`vqa` and `trace`), the resolver gains a
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
camera at the same timestamp; without `camera=`, those resolvers see two
matches and raise an ambiguity error. Recipes consume each camera through its
own binding plus a matching image block, e.g.
```yaml
ask_vqa_top:
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- { type: image, feature: observation.images.top }
- { type: text, text: "${vqa_query}" }
- {
role: assistant,
content: "${vqa}",
stream: high_level,
target: true,
if_present: vqa,
}
```
Add one such sub-recipe per camera the dataset records.
## Layer 3 — training format
Rendered samples use HF-style chat messages plus LeRobot sidecars:
```python
sample["messages"]
sample["message_streams"]
sample["target_message_indices"]
```
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
## Graceful absence
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.

View File

@@ -10,7 +10,6 @@ This docs will guide you to:
- Stream datasets without downloading using `StreamingLeRobotDataset`
- Apply image transforms for data augmentation during training
- Migrate existing `v2.1` datasets to `v3.0`
- Experiment with other `LeRobotDataset` formats and implementations like Lance
## Whats new in `v3`
@@ -44,7 +43,7 @@ lerobot-record \
--dataset.num_episodes=5 \
--dataset.single_task="Grab the black cube" \
--dataset.streaming_encoding=true \
# --dataset.camera_encoder.vcodec=auto \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
```
@@ -316,39 +315,3 @@ Dataset v3.0 uses incremental parquet writing with buffered metadata for efficie
- Ensures the dataset is valid for loading
Without calling `finalize()`, your parquet files will be incomplete and the dataset won't load properly.
## Other formats and implementations
### Lance
Lance is a useful format for multimodal AI datasets, especially for large-scale training requiring high performance IO and random access.
The `lerobot-lancedb` package implements `LeRobotLanceDataset` (for JPEG images) and `LeRobotLanceVideoDataset` (for mp4 videos).
Those two storage layouts both subclass LeRobotDataset and can provide data loading speed ups.
`LeRobotLanceDataset` is a drop-in replacement for `LeRobotDataset`:
```python
from lerobot.datasets import LeRobotDatasetMetadata
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot_lancedb import LeRobotLanceDataset, LeRobotLanceVideoDataset
cfg = DiffusionConfig(...)
meta = LeRobotDatasetMetadata(root=local_dataset_path) # or use repo_id=... to load metadata from the Hub
delta_timestamps = {...}
# Use LeRobotLanceDataset for image datasets
dataset = LeRobotLanceDataset(
root=local_dataset_path, # or use repo_id=... to stream from the Hub
delta_timestamps=delta_timestamps,
return_uint8=True,
)
# Or use LeRobotLanceVideoDataset for video datasets:
dataset = LeRobotLanceVideoDataset(
root=local_dataset_path, # or use repo_id=... to stream from the Hub
delta_timestamps=delta_timestamps,
return_uint8=True,
)
```
Join the discussion on [Github](https://github.com/huggingface/lerobot/issues/3608) and explore the `lerobot-lancedb` documentation [here](https://lancedb.github.io/lerobot-lancedb/).

View File

@@ -1,433 +0,0 @@
# MolmoAct2 Policy
MolmoAct2 is the LeRobot policy implementation of
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into the LeRobot
training, evaluation, checkpointing, and dataset interfaces for easier use with
LeRobot datasets.
This implementation currently supports training and evaluation for the regular
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
not included in this LeRobot policy yet and is coming soon.
For the original MolmoAct2 training code used for the experiments reported in
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
## Installation Requirements
Install LeRobot with the MolmoAct2 optional dependencies:
```bash
pip install -e ".[molmoact2]"
```
To run the models in this repository, you need an NVIDIA GPU. The measurements
below were taken on a single NVIDIA H100 80GB with bf16 model loading, LIBERO with two RGB cameras. MolmoAct2 rows use `chunk_size=10`, action dim 7
padded to `expected_max_action_dim=32`, and `num_flow_timesteps=8`. Training measurements use
`gradient_checkpointing=true` and include the forward pass, backward pass,
gradient clipping, optimizer step, and optimizer state allocation. Values are
peak GPU memory sampled with `nvidia-smi`. Leave a few GiB of headroom for
dataloader workers, CUDA context, and fragmentation.
Multi-GPU training through `accelerate` increases throughput and global batch
size, but this LeRobot port does not currently expose the original MolmoAct2
`fsdp_devices` model-parallel training path. The current training script has
not been tested for multi-node training.
| Mode | Peak Memory, bs=8 | Peak Memory, bs=16 | Peak Memory, bs=32 |
| ------------------------------------------------ | ----------------: | -----------------: | -----------------: |
| Inference, continuous, CUDA graph enabled (bs=1) | 12.1 GiB | - | - |
| Fine-tuning, action expert only, continuous | 16.5 GiB | 18.3 GiB | 21.4 GiB |
| Fine-tuning, LoRA VLM, both action modes | 20.2 GiB | 26.8 GiB | 41.3 GiB |
| Fine-tuning, full model, both action modes | 48.3 GiB | 49.8 GiB | 60.1 GiB |
The repo has been tested with Ubuntu 22.04.
## Usage
To use MolmoAct2 in a LeRobot training config, set:
```python
policy.type=molmoact2
```
## Training
MolmoAct2 can be fine-tuned from either the released MolmoAct2 Hugging Face
checkpoint format or from a checkpoint already saved by LeRobot. Both routes use
the same LeRobot training loop, dataset transforms, checkpoint saving, and
logging. The difference is only how the initial policy weights and processor
state are loaded.
### Training With Original MolmoAct2 Weight
Use `policy.checkpoint_path` when starting from a released MolmoAct2 checkpoint,
for example `allenai/MolmoAct2` or `allenai/MolmoAct2-LIBERO`. LeRobot will load
the original HF model files, then build its own policy processor from the
dataset metadata and the policy options below.
The command below shows full fine-tuning on the merged LIBERO dataset. It uses
bf16 model loading, 8 flow timesteps, LeRobot dataset statistics, image
augmentation, and LeRobot's checkpointing/logging path.
```bash
accelerate launch \
--num_processes=8 \
--mixed_precision=bf16 \
-m lerobot.scripts.lerobot_train \
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
--dataset.video_backend=pyav \
--dataset.image_transforms.enable=true \
--policy.type=molmoact2 \
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
--policy.device=cuda \
--policy.action_mode=both \
--policy.chunk_size=10 \
--policy.n_action_steps=10 \
--policy.setup_type="single franka robotic arm in libero" \
--policy.control_mode="delta end-effector pose" \
--policy.image_keys='["observation.images.image","observation.images.wrist_image"]' \
--policy.model_dtype=bfloat16 \
--policy.num_flow_timesteps=8 \
--policy.gradient_checkpointing=true \
--policy.freeze_embedding=true \
--policy.normalize_gripper=false \
--policy.enable_knowledge_insulation=false \
--policy.push_to_hub=false \
--wandb.enable=true \
--wandb.entity=<wandb_entity> \
--wandb.project=<wandb_project> \
--job_name=<job_name> \
--output_dir=outputs/<job_name> \
--steps=10000 \
--batch_size=32 \
--num_workers=4 \
--log_freq=20 \
--eval_freq=-1 \
--save_checkpoint=true \
--save_freq=2000
```
### Training With LeRobot MolmoAct2 Weight
Use `policy.path` when starting from a MolmoAct2 checkpoint that was saved by
LeRobot, either from a local `pretrained_model` directory or from the Hub. This
restores the saved LeRobot policy config, model weights, processor, and
normalization statistics. You can still override training-time options such as
`batch_size`, `steps`, LoRA flags, or `policy.action_mode`.
```bash
accelerate launch \
--num_processes=8 \
--mixed_precision=bf16 \
-m lerobot.scripts.lerobot_train \
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
--dataset.video_backend=pyav \
--dataset.image_transforms.enable=true \
--policy.path=/path/to/pretrained_model \
--policy.device=cuda \
--policy.action_mode=both \
--policy.chunk_size=10 \
--policy.n_action_steps=10 \
--policy.model_dtype=bfloat16 \
--policy.num_flow_timesteps=8 \
--policy.gradient_checkpointing=true \
--wandb.enable=true \
--wandb.entity=<wandb_entity> \
--wandb.project=<wandb_project> \
--job_name=<job_name> \
--output_dir=outputs/<job_name> \
--steps=10000 \
--batch_size=32 \
--num_workers=4 \
--log_freq=20 \
--eval_freq=-1 \
--save_checkpoint=true \
--save_freq=2000
```
### Common Practices
For fine-tuning on a comparatively small dataset, such as a single LIBERO suite
or a real-world dataset with less than 200 demonstrations, a global batch size of
16 to 32 is a good starting point. In these settings, `policy.enable_lora_vlm=true` or `policy.train_action_expert_only=true` is also a practical choice. In both
cases, we intentionally keep the action expert fully trainable, which we found
to be crucial for model performance. For larger fine-tuning datasets, larger
global batch sizes and full fine-tuning are usually preferred.
### Common Policy Options
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint to initialize from.
Use this for released MolmoAct2 weights.
- `policy.path`: LeRobot checkpoint to initialize from. Use this for checkpoints
created by LeRobot training.
- `policy.action_mode`: training target, one of `continuous`, `discrete`, or
`both`. `both` trains the flow-matching action expert and the discrete
action-token loss.
- `policy.train_action_expert_only`: trains only parameters whose names contain
`action_expert`. It requires `policy.action_mode=continuous`.
- `policy.enable_lora_vlm`: enables LoRA on VLM linear layers. Use
`policy.enable_lora_action_expert=true` only if LoRA should also cover action
expert linear layers. When `policy.enable_lora_action_expert=false`, the
action expert base weights remain fully trainable while the VLM is trained
through LoRA adapters. When `policy.enable_lora_action_expert=true`, the
action expert is also adapter-tuned instead of fully fine-tuned.
- `policy.enable_knowledge_insulation`: when `true`, detaches action-expert
context K/V states before the action loss. The default is `false`.
- `policy.chunk_size`: action horizon used by the policy. For LIBERO we use
`10`. This LeRobot port overrides the loaded checkpoint's
`max_action_horizon` with this value.
- `policy.n_action_steps`: number of actions consumed from each predicted
chunk before querying the policy again. For LIBERO, set it to `chunk_size`.
- `policy.setup_type`: text inserted into the prompt to describe the robot and
scene, e.g. `single franka robotic arm in libero`. More examples are listed
in the `metadata_by_tag` entries of
[`norm_stats.json`](https://huggingface.co/allenai/MolmoAct2/blob/main/norm_stats.json).
- `policy.control_mode`: text inserted into the prompt to describe the action
space, e.g. `delta end-effector pose` or `absolute joint pose`.
- `policy.image_keys`: ordered LeRobot image observation keys passed to the
processor.
- `policy.model_dtype`: checkpoint/forward dtype, one of `float32`,
`bfloat16`, or `float16`. Use `bfloat16` for normal training.
- `policy.num_flow_timesteps`: number of flow-matching timesteps sampled per
example during training. We use `8` for fine-tuning.
- `policy.num_inference_steps`: optional override for continuous action
generation steps at inference time.
- `policy.gradient_checkpointing`: enables checkpointing in the VLM/action path
to reduce activation memory.
- `policy.freeze_embedding`: freezes input embeddings. The default is `true`.
- `policy.normalize_gripper`: controls whether gripper dimensions are included
in state/action quantile normalization. The default is `false`.
- `policy.normalize_language`: normalizes task strings before prompt
construction. The default is `true`.
- `policy.mask_action_dim_padding`: masks padded dimensions in the flow loss.
Released checkpoints use `policy.expected_max_action_dim=32`.
- `policy.max_sequence_length`: optional manual sequence cap. Leave unset to
infer it from images, state dimension, action dimension, action horizon, and
discrete-action mode.
### Learning Rates
MolmoAct2 uses parameter-group learning rates to match the original MolmoAct2
fine-tuning experiments.
- Full fine-tuning uses `policy.optimizer_lr=1e-5` for the VLM,
`policy.optimizer_vit_lr=5e-6` for the vision tower,
`policy.optimizer_connector_lr=5e-6` for image connector layers, and
`policy.optimizer_action_expert_lr=5e-5` for the action expert.
- LoRA VLM fine-tuning sets the VLM, vision, and connector LoRA parameter
groups to `5e-5` when `policy.enable_lora_vlm=true`. By default,
`policy.enable_lora_action_expert=false`, so the action expert is still fully
fine-tuned with `policy.optimizer_action_expert_lr`. If
`policy.enable_lora_action_expert=true`, the action expert is trained through
LoRA adapters instead.
- Action-expert-only fine-tuning trains only the action expert and uses
`policy.optimizer_action_expert_lr=5e-5`.
You can override the full fine-tuning and action-expert learning rates with
`policy.optimizer_lr`, `policy.optimizer_vit_lr`,
`policy.optimizer_connector_lr`, and `policy.optimizer_action_expert_lr`.
Scheduler settings can be changed with `policy.scheduler_warmup_steps`,
`policy.scheduler_decay_steps`, and `policy.scheduler_decay_lr`.
### Dataset Quantile Statistics
MolmoAct2 defaults to quantile normalization for state and action features. If
your dataset has not been converted with quantile statistics, you can add them
with:
```bash
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
--repo-id=your_dataset
```
Alternatively, train MolmoAct2 with mean/std normalization:
```bash
--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'
```
## Evaluation
Evaluation also supports both LeRobot-saved checkpoints and original MolmoAct2
HF checkpoints. For LIBERO replication, keep the EGL rendering environment
fixed and use `policy.per_episode_seed=true`.
**Important:** We found that `num_steps_wait=10` does not reliably let the
LIBERO scene stabilize and can degrade measured success. All LIBERO evaluation
results reported here use `num_steps_wait=50`.
### Evaluation With LeRobot MolmoAct2 Weight
Use `policy.path` for a checkpoint saved by LeRobot. The saved processor and
normalization statistics are restored together with the model.
```bash
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
lerobot-eval \
--policy.path=allenai/MolmoAct2-LIBERO-LeRobot \
--policy.inference_action_mode=continuous \
--policy.model_dtype=bfloat16 \
--policy.use_amp=true \
--policy.enable_inference_cuda_graph=true \
--policy.device=cuda \
--policy.per_episode_seed=true \
--policy.eval_seed=1000 \
--env.type=libero \
--env.task=libero_10,libero_goal,libero_object,libero_spatial \
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
--eval.batch_size=1 \
--eval.n_episodes=50 \
--seed=1000
```
### Evaluation With Original MolmoAct2 Weight
You can evaluate a released Hugging Face checkpoint directly without first
converting it to a LeRobot checkpoint. In this case, set
`policy.checkpoint_path` to the HF model repo and provide `policy.norm_tag`.
For LIBERO, `policy.norm_tag=libero` loads the LIBERO action/state
normalization statistics, action horizon, prompt metadata, and image-key order
from the checkpoint's `norm_stats.json`.
To fully replicate the MolmoAct2 paper results with released Hugging Face
checkpoints, we recommend using the v0.5.1-pinned
[`allenai/lerobot` `molmoact2-hf-inference`](https://github.com/allenai/lerobot/tree/molmoact2-hf-inference)
branch. That branch matches the original evaluation settings used for the
reported numbers.
```bash
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
lerobot-eval \
--policy.type=molmoact2 \
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
--policy.norm_tag=libero \
--policy.inference_action_mode=continuous \
--policy.model_dtype=float32 \
--policy.use_amp=false \
--policy.enable_inference_cuda_graph=true \
--policy.device=cuda \
--policy.per_episode_seed=true \
--policy.eval_seed=1000 \
--env.type=libero \
--env.task=libero_goal \
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
--eval.batch_size=1 \
--eval.n_episodes=50 \
--seed=1000
```
Use `--env.task=libero_10,libero_goal,libero_object,libero_spatial` to run the
full LIBERO suite. The same command works for other released MolmoAct2
checkpoints as long as the requested `policy.norm_tag` exists in that
checkpoint's `norm_stats.json`.
### Common Evaluation Options
- `policy.inference_action_mode`: required for rollout. Use `continuous` for
flow-matching inference or `discrete` for action-token inference. It must be
compatible with the training-time `policy.action_mode` saved in the
checkpoint.
- `policy.path`: LeRobot checkpoint path or Hub repo. Use this for checkpoints
saved by LeRobot.
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint path or Hub repo.
Use this with `policy.type=molmoact2` and `policy.norm_tag`.
- `policy.norm_tag`: selects normalization statistics, prompt metadata,
image-key order, and action horizon from the original checkpoint's
`norm_stats.json`. It is required for direct original-HF checkpoint
evaluation.
- `policy.model_dtype`: model load/forward dtype. Use `bfloat16` for normal
GPU evaluation. Use `float32` only when you explicitly want fp32 inference.
- `policy.use_amp`: runs the policy forward under autocast during eval. For
`model_dtype=bfloat16`, keep this enabled.
- `policy.enable_inference_cuda_graph`: enables the MolmoAct2 inference CUDA
graph path for faster repeated continuous-action rollout.
- `policy.per_episode_seed` and `policy.eval_seed`: make stochastic continuous
action generation deterministic per episode for replication.
- `env.task`: comma-separated LIBERO suites or a single suite. Use
`libero_10,libero_goal,libero_object,libero_spatial` for the full benchmark.
- `env.camera_name_mapping`: maps LIBERO camera names to the image keys expected
by the policy processor.
## Performance Results
### LIBERO Benchmark Results
MolmoAct2 has demonstrated strong performance on the LIBERO benchmark suite. To
compare and test its LeRobot implementation, we fine-tuned
[`allenai/MolmoAct2-LIBERO`](https://huggingface.co/allenai/MolmoAct2-LIBERO)
for an additional 10k steps on the LIBERO dataset with per-GPU batch size 32 on
8 H100 GPUs, then compared the results to the original MolmoAct2 reference
results.
The LeRobot fine-tuned checkpoint reported here is available at
[`allenai/MolmoAct2-LIBERO-LeRobot`](https://huggingface.co/allenai/MolmoAct2-LIBERO-LeRobot)
and was trained on
[`allenai/MolmoAct2-LIBERO-Dataset`](https://huggingface.co/datasets/allenai/MolmoAct2-LIBERO-Dataset).
| Benchmark | LeRobot Implementation | MolmoAct2 Original |
| -------------- | ---------------------: | -----------------: |
| LIBERO Spatial | 98.4% | 97.8% |
| LIBERO Object | 100.0% | 100.0% |
| LIBERO Goal | 98.0% | 97.8% |
| LIBERO 10 | 96.6% | 93.2% |
| Average | 98.25% | 97.20% |
These results demonstrate MolmoAct2's strong performance across diverse robotic
manipulation tasks. To reproduce them, follow the instructions in the LIBERO
evaluation section.
## Differences From the Original Implementation
This LeRobot port is intended to match MolmoAct2 behavior while using LeRobot's
dataset, training, evaluation, checkpoint, and logging infrastructure. The main
differences from the original training repository are:
- The original paper training stack loads the model in fp32 and trains under
mixed precision. This LeRobot port usually loads the checkpoint directly in
`policy.model_dtype=bfloat16` for lower memory use.
- The original repository uses its own FSDP/model-parallel training path. The
LeRobot port uses the standard LeRobot/Accelerate training path and has not
been tested for multi-node training.
- The original repository supports sequence packing. The LeRobot port trains on
one LeRobot sample per item and pads to an inferred fixed sequence budget.
- The LeRobot port follows LeRobot's optimizer, scheduler, checkpoint saving,
dataset transforms, image augmentation, and Weights & Biases logging
conventions.
- The original training path supports mixed action horizons by padding to
`max_action_horizon` and masking padded horizon slots in the action expert
self-attention. This is useful when training across datasets with different
control frequencies. The LeRobot port currently targets single-dataset
fine-tuning, so `policy.chunk_size` overrides the checkpoint
`max_action_horizon` and horizon masking is not implemented yet. Support for
this mixed-horizon path is planned.
## Citation
```bibtex
@misc{fang2026molmoact2actionreasoningmodels,
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
year={2026},
eprint={2605.02881},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2605.02881},
}
```
## License
This model is licensed under Apache 2.0. It is intended for research and
educational use in accordance with
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).

View File

@@ -1,39 +0,0 @@
# MolmoAct2
This repository contains the LeRobot policy implementation of
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into LeRobot for
training, evaluation, checkpointing, and dataset compatibility.
This implementation currently supports training and evaluation for the regular
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
not included in this LeRobot policy yet and is coming soon.
For the original MolmoAct2 training code used for the experiments reported in
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
## LIBERO Evaluation
Important: we found that `num_steps_wait=10` does not reliably let the LIBERO
scene stabilize and can degrade measured success. All LIBERO evaluation results
reported for this LeRobot implementation use `num_steps_wait=50`.
## Citation
```bibtex
@misc{fang2026molmoact2actionreasoningmodels,
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
year={2026},
eprint={2605.02881},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2605.02881},
}
```
## License
This model is licensed under Apache 2.0. It is intended for research and
educational use in accordance with
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).

View File

@@ -1,39 +0,0 @@
# VLA-JEPA
This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA).
---
## Architecture Overview
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
At inference time only the Qwen backbone and action head are used; the world model is not needed.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.

View File

@@ -161,7 +161,7 @@ lerobot-record \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
# --dataset.vcodec=auto \
--display_data=true
```
@@ -203,7 +203,7 @@ lerobot-record \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
# --dataset.vcodec=auto \
--display_data=true
```

View File

@@ -1,186 +0,0 @@
# reBot B601-DM
[reBot B601-DM](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/) is an open-source, low-cost robot arm from Seeed Studio for embodied-AI and imitation learning. It comes as a **follower** arm (the `B601-DM`, a 6-DOF arm plus gripper driven by Damiao CAN motors) and a **leader** arm (the `StarArm102` / `reBot Arm 102`, driven by FashionStar UART smart servos) used to teleoperate it.
This page covers **calibration** and **teleoperation** for both single-arm and bimanual (dual-arm) setups.
<div style="display: flex; align-items: center; gap: 10px;">
<img
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/b601dm_zeroposition.jpg"
alt="reBot B601-DM follower arm at its zero position"
width="48%"
/>
<img
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/102_zeroposition.jpg"
alt="reBot Arm 102 leader arm at its zero position"
width="48%"
/>
</div>
_Left: the B601-DM follower at its zero position. Right: the reBot Arm 102 leader at its zero position. Images courtesy of [Seeed Studio](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/)._
## Install LeRobot 🤗
Follow our [Installation Guide](./installation), then install the reBot support:
```bash
pip install -e ".[rebot]"
```
This pulls in `motorbridge` (CAN motor control for the B601-DM follower) and `motorbridge-smart-servo` (FashionStar UART servos for the reBot Arm 102 leader).
## Registered device types
| Type | Kind |
| ------------------------ | -------------------------------------------- |
| `rebot_b601_follower` | single-arm B601-DM follower robot |
| `bi_rebot_b601_follower` | bimanual (dual-arm) follower robot |
| `rebot_102_leader` | single-arm reBot Arm 102 leader teleoperator |
| `bi_rebot_102_leader` | bimanual (dual-arm) leader teleoperator |
The bimanual types compose two single-arm instances and namespace each arm's
observation/action keys with a `left_` / `right_` prefix. Per-arm settings are
passed through nested `left_arm_config.*` / `right_arm_config.*` arguments.
## Find the USB ports
For each device, find the USB port associated with its motor bus using:
```bash
lerobot-find-port
```
<Tip warning={true}>
On Linux, remove `brltty` (`sudo apt remove brltty`) so it does not hold the
leader's USB serial port. You may also need to grant access to the serial
devices: `sudo chmod 666 /dev/ttyACM* /dev/ttyUSB*`.
</Tip>
## Calibration
Neither arm stores a persistent hardware calibration: every time it connects, the motors are re-zeroed against the pose the arm is physically holding. Calibration simply records that zero pose. When prompted, **manually move the arm to its zero position** (the default sit-down pose shown above, gripper fully closed) and press <kbd>ENTER</kbd>.
### Follower (B601-DM)
<hfoptions id="calibrate-follower">
<hfoption id="Single arm">
```bash
lerobot-calibrate \
--robot.type=rebot_b601_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=follower \
--robot.can_adapter=damiao
```
</hfoption>
<hfoption id="Dual arm">
Connect the bimanual follower; calibration runs for the left arm, then the right arm.
```bash
lerobot-calibrate \
--robot.type=bi_rebot_b601_follower \
--robot.id=bi_follower \
--robot.left_arm_config.port=/dev/ttyACM0 \
--robot.left_arm_config.can_adapter=damiao \
--robot.right_arm_config.port=/dev/ttyACM1 \
--robot.right_arm_config.can_adapter=damiao
```
Per-arm calibration files are saved with `_left` / `_right` suffixes on the id.
</hfoption>
</hfoptions>
### Leader (reBot Arm 102)
<hfoptions id="calibrate-leader">
<hfoption id="Single arm">
```bash
lerobot-calibrate \
--teleop.type=rebot_102_leader \
--teleop.port=/dev/ttyUSB0 \
--teleop.id=leader
```
</hfoption>
<hfoption id="Dual arm">
```bash
lerobot-calibrate \
--teleop.type=bi_rebot_102_leader \
--teleop.id=bi_leader \
--teleop.left_arm_config.port=/dev/ttyUSB0 \
--teleop.right_arm_config.port=/dev/ttyUSB1
```
</hfoption>
</hfoptions>
## Teleoperation
Once both arms are calibrated, drive the follower with the leader. The follower talks to its CAN bus through a Damiao serial bridge (`can_adapter=damiao`, the default) or a SocketCAN adapter (`can_adapter=socketcan`). See the [OpenArm page](./openarm) for more details on the SocketCAN adapter configuration.
<hfoptions id="teleoperate">
<hfoption id="Single arm">
```bash
lerobot-teleoperate \
--robot.type=rebot_b601_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=follower \
--robot.can_adapter=damiao \
--teleop.type=rebot_102_leader \
--teleop.port=/dev/ttyUSB0 \
--teleop.id=leader
```
</hfoption>
<hfoption id="Dual arm">
The bimanual leader and follower reuse the single-arm classes; each arm is
configured through nested `left_arm_config.*` / `right_arm_config.*` arguments,
so a bimanual reBot Arm 102 leader drives a bimanual B601-DM follower.
```bash
lerobot-teleoperate \
--robot.type=bi_rebot_b601_follower \
--robot.id=bi_follower \
--robot.left_arm_config.port=/dev/ttyACM0 \
--robot.left_arm_config.can_adapter=damiao \
--robot.right_arm_config.port=/dev/ttyACM1 \
--robot.right_arm_config.can_adapter=damiao \
--teleop.type=bi_rebot_102_leader \
--teleop.id=bi_leader \
--teleop.left_arm_config.port=/dev/ttyUSB0 \
--teleop.right_arm_config.port=/dev/ttyUSB1
```
</hfoption>
</hfoptions>
<Tip>
The leader and follower share the same joint names (`shoulder_pan,
shoulder_lift, elbow_flex, wrist_flex, wrist_yaw, wrist_roll, gripper`), so
leader actions map directly onto the follower.
</Tip>
If the motion of a joint is reversed, flip its sign in the leader's `joint_directions` (the gripper also carries a scale to widen its range to the follower):
```bash
lerobot-teleoperate \
--robot.type=rebot_b601_follower \
--robot.port=/dev/ttyACM0 \
--robot.can_adapter=damiao \
--teleop.type=rebot_102_leader \
--teleop.port=/dev/ttyUSB0 \
--teleop.joint_directions='{"shoulder_pan":-1,"shoulder_lift":-1,"elbow_flex":1,"wrist_flex":1,"wrist_yaw":1,"wrist_roll":-1,"gripper":-6}'
```
## Recording datasets
Swap `lerobot-teleoperate` for `lerobot-record` (with the same `--robot.*` / `--teleop.*` arguments, plus `--dataset.*`) to record demonstrations for training. See [Imitation Learning for Robots](./il_robots) for the full workflow.
For hardware assembly and wiring, see the [Seeed Studio reBot wiki](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/).

View File

@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
Once you are logged in, you can run inference in your setup by doing:
```bash
lerobot-rollout \
--strategy.type=base \
lerobot-record \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \ # <- Use your port
--robot.id=my_blue_follower_arm \ # <- Use your robot id
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
# <- RTC optional, use when running on low power hardware \
# --inference.type=rtc \
# --inference.rtc.execution_horizon=10 \
# --inference.rtc.max_guidance_weight=10.0 \
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
--dataset.episode_time_s=50 \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
# --teleop.id=my_red_leader_arm \
# --display_data=true #optional use if you want to see the camera stream \
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
```

View File

@@ -17,9 +17,9 @@ This makes `save_episode()` near-instant (the video is already encoded by the ti
| Parameter | CLI Flag | Type | Default | Description |
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
| `vcodec` | `--dataset.camera_encoder.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `30` | Max buffered frames per camera (~1s at 30fps). Consumes RAM |
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
## 3. Performance Considerations
@@ -48,7 +48,7 @@ This parameter controls how many threads each encoder instance uses internally:
### Backpressure and Frame Dropping
Each camera has a bounded queue (`encoder_queue_maxsize`, default 30 frames). When the encoder can't keep up:
Each camera has a bounded queue (`encoder_queue_maxsize`, default 60 frames). When the encoder can't keep up:
1. The queue fills up (consuming RAM)
2. New frames are **dropped** (not blocked) — the capture loop continues uninterrupted
@@ -82,15 +82,15 @@ Use HW encoding when:
### Available HW Encoders
| Encoder | Platform | Hardware | CLI Value |
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | --------------------------------------------------- |
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder.vcodec=h264_videotoolbox` |
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder.vcodec=hevc_videotoolbox` |
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder.vcodec=h264_nvenc` |
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder.vcodec=hevc_nvenc` |
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.camera_encoder.vcodec=h264_vaapi` |
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.camera_encoder.vcodec=h264_qsv` |
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.camera_encoder.vcodec=auto` |
| Encoder | Platform | Hardware | CLI Value |
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
> [!NOTE]
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
@@ -100,15 +100,15 @@ Use HW encoding when:
## 5. Troubleshooting
| Symptom | Likely Cause | Fix |
| ------------------------------------------------------------------ | -------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.camera_encoder.vcodec=auto`) |
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.camera_encoder.vcodec=auto`). |
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.camera_encoder.vcodec=auto` |
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
| Symptom | Likely Cause | Fix |
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
## 6. Recommended Configurations
@@ -146,7 +146,7 @@ On very constrained systems, streaming encoding may compete too heavily with the
# 2camsx 640x480x3 @30fps: Requires some tuning.
# Use H.264, disable streaming, consider batching encoding
lerobot-record --dataset.camera_encoder.vcodec=h264 --dataset.streaming_encoding=false ...
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
```
## 7. Closing note

View File

@@ -1,210 +0,0 @@
# Tools
LeRobot v3.1 supports **tool calls** in policies — assistant messages can
emit structured invocations like `say(text="OK, starting now")` that the
runtime dispatches to a real implementation (TTS, controller, logger, …).
This page covers:
1. Where the tool catalog lives.
2. How the annotation pipeline produces tool-call atoms.
3. How to add your own tool.
## Where tools are declared
Two layers.
**The catalog** — a list of OpenAI-style function schemas — lives at
`meta/info.json["tools"]` on each dataset. Example:
```json
{
"features": { "...": "..." },
"tools": [
{
"type": "function",
"function": {
"name": "say",
"description": "Speak a short utterance to the user via the TTS executor.",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The verbatim text to speak."
}
},
"required": ["text"]
}
}
}
]
}
```
Read it via the dataset metadata accessor:
```python
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
tools = meta.tools # list[dict] — OpenAI tool schemas
```
If the dataset's `info.json` doesn't declare any tools, `meta.tools`
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
single-entry list with the canonical `say` schema. So unannotated
datasets and chat-template consumers keep working without any
configuration:
```python
prompt_str = tokenizer.apply_chat_template(
sample["messages"],
tools=meta.tools, # works either way
add_generation_prompt=False,
tokenize=False,
)
```
**The implementations** — runnable Python — will live under
`src/lerobot/tools/`, one file per tool. The runtime dispatcher and
the canonical `say` implementation (wrapping Kyutai's pocket-tts) are
not part of the catalog layer described here; today this layer ships
only the schema storage and the `DEFAULT_TOOLS` fallback constant.
## Per-row tool _invocations_
The catalog above describes _what can be called_. The actual _call_ — the
function name plus the argument values — is stored per-row, on the
assistant atoms in `language_events`:
```python
{
"role": "assistant",
"content": null,
"style": null,
"timestamp": 12.4,
"camera": null,
"tool_calls": [
{ "type": "function",
"function": { "name": "say", "arguments": { "text": "On it." } } }
]
}
```
Recipes splice these into rendered messages via `tool_calls_from`:
```yaml
user_interjection_response:
bindings:
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- { role: user, content: "${task}", stream: high_level }
- {
role: assistant,
content: "${current_plan}",
stream: high_level,
target: true,
tool_calls_from: speech,
}
```
The model's training target is one assistant turn that carries both the
plan text _and_ the `say` tool call. At inference, the runtime parses
the generated text back into structured `tool_calls` and dispatches to
the matching implementation.
## How to add your own tool
> **Note:** Steps 2 and 3 below describe the runtime layer
> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`,
> `get_tools(meta)`) which is not part of the catalog layer shipped
> today — those modules don't yet exist in the tree. Step 1 alone is
> enough to make the tool visible to the chat template via
> `meta.tools` so the model can learn to _generate_ the call;
> executing the call at inference requires the runtime layer.
Three steps. Concrete example: a `record_observation` tool the policy
can call to capture an extra observation outside the regular control
loop.
### Step 1 — declare the schema
Add an entry under `meta/info.json["tools"]`. Either edit the file
directly on disk _before_ running the annotation pipeline (it'll be
preserved) or hand it to `lerobot-annotate` via a config flag.
```json
{
"tools": [
{ "type": "function", "function": { "name": "say", "...": "..." } },
{
"type": "function",
"function": {
"name": "record_observation",
"description": "Capture a high-resolution still image for the user.",
"parameters": {
"type": "object",
"properties": {
"label": {
"type": "string",
"description": "Short label for the saved image."
}
},
"required": ["label"]
}
}
}
]
}
```
The schema follows OpenAI's function-calling convention exactly, so the
chat template can render it natively.
### Step 2 — implement the call
Create `src/lerobot/tools/record_observation.py`:
```python
from .base import Tool
from typing import Any
RECORD_OBSERVATION_SCHEMA: dict[str, Any] = { "...": "..." } # mirrors the JSON above
class RecordObservationTool:
name = "record_observation"
schema = RECORD_OBSERVATION_SCHEMA
def __init__(self, schema: dict | None = None, output_dir: str = "."):
self.output_dir = output_dir
def call(self, arguments: dict) -> str:
label = arguments["label"]
# ... save the latest camera frame to <output_dir>/<label>.png ...
return f"saved {label}.png"
```
One file per tool keeps dependencies isolated — `record_observation`
might pull `pillow`, while `say` pulls `pocket-tts`. Users installing
only the tools they need avoid heavy transitive deps.
### Step 3 — register it
Add to `src/lerobot/tools/registry.py`:
```python
from .record_observation import RecordObservationTool
TOOL_REGISTRY["record_observation"] = RecordObservationTool
```
That's it. At runtime `get_tools(meta)` looks up each schema in
`meta.tools`, instantiates the matching registered class, and returns
a name → instance dict the dispatcher can route into.
If you want to use a tool _without_ writing an implementation (e.g. for
training-time chat-template formatting only), step 1 alone is enough —
the model still learns to _generate_ the call. Steps 2 and 3 are only
needed to actually _execute_ it at inference.

View File

@@ -1,177 +0,0 @@
# TOPReward
TOPReward is a **zero-shot reward model** that extracts token log-probabilities from an off-the-shelf vision-language model (VLM) as a robotic reward signal. Given a video trajectory and a task instruction, it returns the VLM's log-likelihood that the instruction is true — no fine-tuning required.
**Paper**: [TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics](https://arxiv.org/abs/2602.19313)
**Project**: [topreward.github.io](https://topreward.github.io/webpage/)
**Original code**: [github.com/TOPReward/TOPReward](https://github.com/TOPReward/TOPReward)
**Default backbone**: [Qwen/Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
## Overview
TOPReward asks a generic VLM how likely a task instruction is, **conditioned on the video** of a robot trying to complete that task. Concretely, given:
- A trajectory video (a sequence of frames).
- A task instruction (e.g. _"open the drawer"_).
it builds a chat prompt of the form
```text
<video>
"The above video shows a robot manipulation trajectory that completes the
following task: <instruction> Decide whether the above statement is True
or not. The answer is: True"
```
forwards it through the VLM, label-masks everything except the very last token, and reads back the log-probability of that token — by default the literal `"True"` that closes the suffix template. The resulting `log P("True" | video + prompt + instruction)` is the reward.
Because the method only depends on a frozen VLM, TOPReward is **zero-shot**: there are no fine-tuned weights to host. The "model" in LeRobot is a small wrapper around `transformers`' `Qwen3VLForConditionalGeneration` plus the label-masking logic. The processor owns the tokeniser and builds the full chat prompt (EO-1/Robometer pattern).
## What the LeRobot integration covers
- Standard `reward_model.type=topreward` configuration through LeRobot.
- VLM loading via the `transformers` `Qwen3VLForConditionalGeneration` API.
- Prompt assembly + tokenisation in the processor (matching upstream `QwenClient.compute_instruction_reward`).
- `compute_reward()` returns one scalar log-prob per sample.
- LeRobot reward-model save/load — `save_pretrained` writes only `config.json` (the VLM is identified by `vlm_name`).
- An offline labeling script that writes a `topreward_progress.parquet` (SARM-compatible schema) for RA-BC and overlay.
The current LeRobot port supports the **Qwen3-VL client only**. Other upstream clients (Gemini, OpenAI, Gemma, Molmo) can be added as follow-up extras.
## Installation Requirements
1. Install LeRobot following the [Installation Guide](./installation).
2. Install the TOPReward optional extra:
```bash
pip install -e ".[topreward]"
```
or, with `uv` from a source checkout:
```bash
uv sync --extra topreward
```
This pulls in `transformers`. The first time you run TOPReward, Hugging Face will also download the VLM weights from the Hub (~16 GB for Qwen3-VL-8B-Instruct). A GPU is strongly recommended.
## Model Inputs and Outputs
TOPReward expects:
- A trajectory video or sequence of frames.
- A natural-language task description.
In LeRobot datasets the preprocessor reads:
| Config field | Default | Meaning |
| ------------------------- | --------------------------- | --------------------------------------------- |
| `reward_model.image_key` | `observation.images.top` | Camera observation used by TOPReward |
| `reward_model.task_key` | `task` | Key in complementary data for the task string |
| `reward_model.max_frames` | `16` | Cap on frames per sample |
| `reward_model.fps` | `2.0` | Metadata passed to the Qwen video processor |
| `reward_model.vlm_name` | `Qwen/Qwen3-VL-8B-Instruct` | Hugging Face Hub id of the underlying VLM |
The model returns:
- `compute_reward(batch)`: one log-probability per sample. Higher = better task-video alignment. When `success_threshold` is finite, returns the binary thresholded value instead.
## Usage
### Load the reward model directly
```python
from lerobot.rewards.topreward import TOPRewardConfig, TOPRewardModel
cfg = TOPRewardConfig(
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
device="cuda",
)
reward_model = TOPRewardModel(cfg)
```
### Use the reward factory
```python
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
cfg = make_reward_model_config(
"topreward",
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
device="cuda",
image_key="observation.images.top",
)
reward_model = make_reward_model(cfg)
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
```
The preprocessor tokenises the full prompt (video + prefix + instruction suffix), writes Qwen-VL tensors + `prompt_length` under `observation.topreward.*`. The model reads those tensors, label-masks based on `prompt_length`, and extracts the log-prob reward.
### Offline dataset labeling
Write a `topreward_progress.parquet` for RA-BC training and overlay videos:
```bash
# Sparse-dense (15 anchors per episode, matches upstream)
uv run python -m lerobot.rewards.topreward.compute_rabc_weights \
--dataset-repo-id lerobot/libero_10_image \
--num-samples 15 \
--device cuda
```
Then render the progress overlay for any episode:
```bash
uv run examples/dataset/create_progress_videos.py \
--repo-id lerobot/libero_10_image \
--episode 0 \
--progress-file topreward_progress.parquet \
--gif
```
## Configuration Notes
### Prompt knobs
The default prompt mirrors the upstream paper:
```text
prompt_prefix = "The above video shows a robot manipulation trajectory that completes the following task: "
prompt_suffix_template = "{instruction} Decide whether the above statement is True or not. The answer is: True"
```
Both are exposed on `TOPRewardConfig` for ablation. The suffix template **must** contain `{instruction}`.
### Chat template
`add_chat_template=True` wraps the full prompt (including instruction) with the tokenizer's chat template before tokenisation. Default is `False`, matching the upstream paper's main experiments.
## Limitations
- The current LeRobot port is **inference-only and zero-shot**; `forward()` is not overridden and `is_trainable` returns `False`.
- Only the **Qwen3-VL family** is supported; other upstream clients are out of scope.
- TOPReward inherits the underlying VLM's biases.
## References
- [TOPReward project page](https://topreward.github.io/webpage/)
- [TOPReward paper](https://arxiv.org/abs/2602.19313)
- [Original TOPReward code](https://github.com/TOPReward/TOPReward)
- [Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
## Citation
```bibtex
@article{chen2026topreward,
title={TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics},
author={Chen, Shirui and Harrison, Cole and Lee, Ying-Chun and Yang, Angela Jin and
Ren, Zhongzheng and Ratliff, Lillian J and Duan, Jiafei and Fox, Dieter and
Krishna, Ranjay},
journal={arXiv preprint arXiv:2602.19313},
year={2026}
}
```
## License
The original TOPReward codebase is MIT-licensed. The LeRobot port follows the LeRobot Apache 2.0 license; the wrapped Qwen3-VL weights are subject to the original Qwen license.

View File

@@ -117,10 +117,10 @@ lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \
--operation.output_dir outputs/pusht_video \
--operation.camera_encoder.vcodec libsvtav1 \
--operation.camera_encoder.pix_fmt yuv420p \
--operation.camera_encoder.g 2 \
--operation.camera_encoder.crf 30
--operation.vcodec libsvtav1 \
--operation.pix_fmt yuv420p \
--operation.g 2 \
--operation.crf 30
# Convert only specific episodes
lerobot-edit-dataset \
@@ -147,7 +147,11 @@ lerobot-edit-dataset \
**Parameters:**
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
- `camera_encoder`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder.<field>. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
- `fast_decode`: Fast decode tuning option (default: 0)
- `episode_indices`: List of specific episodes to convert (default: all episodes)
- `num_workers`: Number of parallel workers for processing (default: 4)

View File

@@ -1,117 +0,0 @@
# Video encoding parameters
When video storage is enabled, LeRobot stores each camera stream as an **MP4** file instead of saving one image file per timestep. Video encoding compresses across time, which usually cuts dataset size and I/O compared to a pile of PNG, while keeping MP4 — a format every player and loader understands.
Encoding frames into an MP4 is a full FFmpeg pipeline: choice of encoder, pixel format, GOP/keyframes, quality vs. speed, and optional extra encoder flags. Most of these knobs are user-tunable through `camera_encoder`, a nested `VideoEncoderConfig` (`lerobot.configs.video.VideoEncoderConfig`) passed through PyAV.
You can set these parameters from the CLI with `--dataset.camera_encoder.<field>` (e.g. with `lerobot-record` or `lerobot-rollout`). The same block applies to every camera video stream in that run.
<Tip>
Video storage must be on for `camera_encoder` to have any effect —
`use_videos=True` in Python APIs, or `--dataset.video=true` on the CLI (the
recording default). With video off, inputs stay as images and `camera_encoder`
is ignored.
</Tip>
For details on **when** frames are written vs. encoded (streaming vs. post-episode), queues, and other top-level `--dataset.*` switches, see [Streaming Video Encoding](./streaming_video_encoding). For an encoding-parameter comparison and experiments, see the [video-benchmark Space](https://huggingface.co/spaces/lerobot/video-benchmark).
---
## Example
```bash
lerobot-record \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--robot.id=black \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=blue \
--dataset.repo_id=<my_username>/<my_dataset_name> \
--dataset.num_episodes=2 \
--dataset.single_task="Grab the cube" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
--dataset.camera_encoder.vcodec=h264 \
--dataset.camera_encoder.preset=fast \
--dataset.camera_encoder.extra_options={"tune": "film", "profile:v": "high", "bf": 2} \
--display_data=true
```
---
## Tuning parameters
<Tip warning={true}>
The defaults are tuned to balance **compression ratio**, **visual quality**, and **decoding/seek speed** for typical robotics datasets. Changing them can affect both recording (CPU load, frame drops) and training (decoding throughput, image quality).
Only override these parameters if you have a specific reason to, and measure the impact on your pipeline before relying on the new settings.
</Tip>
All flags below are prefixed with `--dataset.camera_encoder.` on the CLI.
| Parameter | Type | Default | Description |
| --------------- | ---------------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `vcodec` | `str` | `"libsvtav1"` | Video codec name. `"auto"` picks the first available hardware encoder from a fixed preference list, falling back to `libsvtav1`. |
| `pix_fmt` | `str` | `"yuv420p"` | Output pixel format. Must be supported by the chosen codec in your FFmpeg build. |
| `g` | `int` | `2` | GOP size — a keyframe every `g` frames. Emitted as FFmpeg option `g`. |
| `crf` | `int` or `float` | `30` | Abstract quality value, mapped per codec (see the [mapping](#mapping-videoencoderconfig--ffmpeg-options) below). Lower → higher quality / larger output where the mapping is monotone. |
| `preset` | `int` or `str` | `12` \* | Encoder speed preset; meaning depends on the codec. <br/>\* When unset and `vcodec=libsvtav1`, LeRobot defaults to `12`. |
| `fast_decode` | `int` | `0` | `libsvtav1`: `02`, passed via `svtav1-params`. <br/>`h264` / `hevc` (software): if `>0`, sets `tune=fastdecode`. <br/>Other codecs: usually unused. |
| `video_backend` | `str` | `"pyav"` | Only `"pyav"` is currently implemented for video encoding. |
| `extra_options` | `dict` | `{}` | Extra FFmpeg or codec specific options merged after the structured fields above. Cannot override keys already set by those fields. |
---
## Persistence in dataset metadata
After the first episode of a video stream is encoded, the encoder configuration is **persisted into the dataset metadata** (`meta/info.json`) under each video feature, alongside the values probed from the file itself. For a video feature `observation.images.<camera>`, the layout in `info.json` is:
```json
{
"features": {
"observation.images.laptop": {
"dtype": "video",
"shape": [480, 640, 3],
"info": {
"video.height": 480,
"video.width": 640,
"video.codec": "h264",
"video.pix_fmt": "yuv420p",
"video.fps": 30,
"video.channels": 3,
"video.is_depth_map": false,
"video.g": 2,
"video.crf": 30,
"video.preset": "fast",
"video.fast_decode": 0,
"video.video_backend": "pyav",
"video.extra_options": { "tune": "film", "profile:v": "high", "bf": 2 }
}
}
}
}
```
Two sources contribute to the `info` block:
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
<Tip>
This block is populated **once**, from the **first** episode. It assumes every
episode in the dataset was encoded with the same `camera_encoder`. Changing
encoder settings partway through a recording is not supported — the
`info.json` will only reflect the parameters used for the first episode.
</Tip>
---
## Merging datasets
When aggregating datasets with `merge_datasets`, video files are concatenated as-is (no re-encoding), and encoder fields in `info.json` are merged per-key:
- **Stream-derived fields must match** across sources: `video.codec`, `video.pix_fmt`, `video.height`, `video.width`, `video.fps`. Otherwise FFmpeg's concat demuxer fails.
- **Encoder-tuning fields are merged loosely**: `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.extra_options`. If every source agrees, the value is kept; if not, it's set to `null` (or `{}` for `video.extra_options`) and a warning is logged.

View File

@@ -1,235 +0,0 @@
# VLA-JEPA
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
---
## Architecture Overview
VLA-JEPA has three main components:
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
### Data flow
**Training:**
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
**Inference:**
Only Qwen + the action head are used. The world model is not needed at inference time.
### Action head details
Available presets via `action_model_type`:
| Preset | Hidden dim | Heads | Head dim |
| ------- | ---------- | ----- | -------- |
| `DiT-B` | 768 | 12 | 64 |
| `DiT-L` | 1536 | 32 | 48 |
### World model details
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
---
## Pretrained Checkpoints
Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
| Checkpoint | Dataset | Cameras | World model | Action dim |
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
---
## Configuration
Key parameters in `VLAJEPAConfig`:
| Parameter | Default | Description |
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `chunk_size` | 7 | Number of actions predicted per inference call |
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 8 | Video clip length fed to the world model |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
---
## Training
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
### Full training from scratch
```bash
lerobot-train \
policy.type=vla_jepa \
policy.repo_id=your_org/your_repo \
dataset.repo_id=your_org/your_dataset
```
### Fine-tuning from a pretrained checkpoint
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/your_dataset
```
If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--dataset.repo_id=your_org/your_dataset
```
### Fine-tuning on a different embodiment
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
The layers that depend on `action_dim` and `state_dim` are:
| Layer | Key prefix |
| ----------------------------------------- | ----------------------------------- |
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
--dataset.repo_id=your_org/your_dataset
```
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
### Reproducing the LIBERO results
**Training on LIBERO:**
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=HuggingFaceVLA/libero \
--steps=30000
```
**Evaluating the pretrained LIBERO-10 checkpoint:**
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.n_episodes=10 \
--eval.batch_size=5
```
To evaluate a subset of tasks only:
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_10 \
--env.task_ids='[0,1,2]' \
--eval.n_episodes=10 \
--eval.batch_size=5
```
**Expected results:**
| Suite | Episodes | Successes | Success Rate |
| -------------- | -------- | --------- | ------------ |
| libero_spatial | 100 | 93 | **95.0%** |
| libero_object | 100 | 100 | **100.0%** |
| libero_goal | 100 | 98 | **98.0%** |
| libero_10 | 100 | 96 | **93.0%** |
| **Overall** | **400** | **387** | **96.5%** |
---
## Fine-tuning on datasets with a different number of cameras
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
**Default behaviour — view padding / trimming (no action required)**
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
**Option 1 — Disable the world model**
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.enable_world_model=false \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/single_camera_dataset
```
**Option 2 — Reinitialize the predictor input projection**
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.

View File

@@ -15,12 +15,10 @@
# limitations under the License.
"""
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
Downloads datasets from HuggingFace, seeks directly into the episode segment
of the source video, draws a progress line on each frame, and writes the result.
The progress data is read from a parquet file that lives alongside the dataset
(configurable via ``--progress-file``).
Usage:
python examples/dataset/create_progress_videos.py \
@@ -58,26 +56,22 @@ SCORE_FONT_SCALE = 0.8
TASK_FONT_SCALE = 0.55
def download_episode_metadata(
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
) -> Path:
"""Download only the metadata and per-frame progress file for a dataset.
def download_episode_metadata(repo_id: str, episode: int) -> Path:
"""Download only the metadata and sarm_progress files for a dataset.
Args:
repo_id: HuggingFace dataset repository ID.
episode: Episode index (used for logging only; all meta is fetched).
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns:
Local cache path for the downloaded snapshot.
"""
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
local_path = Path(
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
allow_patterns=["meta/**", progress_file],
allow_patterns=["meta/**", "sarm_progress.parquet"],
ignore_patterns=["*.mp4"],
)
)
@@ -221,28 +215,25 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
return video_path
def load_progress_data(
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
) -> np.ndarray | None:
"""Load per-frame progress values for an episode.
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
"""Load sarm_progress values for an episode.
Args:
local_path: Dataset cache root.
episode: Episode index.
progress_file: Filename of the per-frame progress parquet.
Returns:
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
"""
parquet_path = local_path / progress_file
parquet_path = local_path / "sarm_progress.parquet"
if not parquet_path.exists():
logging.warning("%s not found", progress_file)
logging.warning("sarm_progress.parquet not found")
return None
df = pd.read_parquet(parquet_path)
logging.info(" %s columns: %s", progress_file, list(df.columns))
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
episode_df = df[df["episode_index"] == episode].copy()
if episode_df.empty:
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
logging.warning("No sarm_progress rows for episode %d", episode)
return None
episode_df = episode_df.sort_values("frame_index")
@@ -585,7 +576,6 @@ def process_dataset(
camera_key: str | None,
output_dir: Path,
create_gif: bool = False,
progress_file: str = "sarm_progress.parquet",
) -> Path | None:
"""Full pipeline: download, extract metadata, composite progress, write output.
@@ -595,8 +585,6 @@ def process_dataset(
camera_key: Camera key to use, or None for auto-selection.
output_dir: Directory to write output files.
create_gif: If True, also generate a GIF from the MP4.
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns:
Path to the final output file, or None on failure.
@@ -604,7 +592,7 @@ def process_dataset(
safe_name = repo_id.replace("/", "_")
logging.info("Processing: %s | episode %d", repo_id, episode)
local_path = download_episode_metadata(repo_id, episode, progress_file)
local_path = download_episode_metadata(repo_id, episode)
logging.info(" Local cache: %s", local_path)
episode_meta = load_episode_meta(local_path, episode, camera_key)
@@ -612,9 +600,9 @@ def process_dataset(
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
progress_data = load_progress_data(local_path, episode, progress_file)
progress_data = load_progress_data(local_path, episode)
if progress_data is None:
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
logging.error("Could not load sarm_progress data. Skipping overlay.")
return None
logging.info(" Progress frames: %d", len(progress_data))
@@ -639,7 +627,7 @@ def process_dataset(
def main() -> None:
parser = argparse.ArgumentParser(
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
)
parser.add_argument(
"--repo-id",
@@ -670,15 +658,6 @@ def main() -> None:
action="store_true",
help="Also generate a GIF from the MP4 output.",
)
parser.add_argument(
"--progress-file",
type=str,
default="sarm_progress.parquet",
help=(
"Filename of the per-frame progress parquet inside the dataset repo "
"(default: 'sarm_progress.parquet')."
),
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
@@ -691,7 +670,6 @@ def main() -> None:
camera_key=args.camera_key,
output_dir=args.output_dir,
create_gif=args.gif,
progress_file=args.progress_file,
)
if result:

View File

@@ -0,0 +1,244 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Create videos with a Robometer progress overlay for one LeRobot dataset episode.
This is a lightweight smoke-test utility for Robometer checkpoints. It downloads
one episode video, samples a small number of frames, runs Robometer on those
frames, and reuses the progress overlay renderer from
``examples/dataset/create_progress_videos.py``.
Example:
uv run python examples/dataset/create_robometer_progress_videos.py \\
--repo-id lerobot/aloha_mobile_cabinet \\
--episode 0 \\
--reward-model-path lilkm/robometer-4b \\
--device cuda
"""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
import cv2
import numpy as np
import torch
from examples.dataset.create_progress_videos import (
composite_progress_video,
convert_mp4_to_gif,
download_episode_metadata,
download_video_file,
load_episode_meta,
)
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
from lerobot.rewards.robometer.modeling_robometer import decode_progress_outputs
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
from lerobot.utils.utils import init_logging
def _default_device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"
def sample_episode_frames(
video_path: Path,
*,
from_timestamp: float,
to_timestamp: float,
fps: float,
num_frames: int,
) -> tuple[np.ndarray, np.ndarray]:
"""Sample RGB frames uniformly from an episode video segment.
Returns:
``(frames, frame_indices)`` where ``frames`` is ``(T,H,W,C)`` uint8 RGB
and ``frame_indices`` are local episode frame indices used for overlay.
"""
if num_frames <= 0:
raise ValueError(f"num_frames must be positive, got {num_frames}")
duration_seconds = to_timestamp - from_timestamp
total_frames = max(int(round(duration_seconds * fps)), 1)
frame_indices = np.linspace(0, total_frames - 1, num=min(num_frames, total_frames), dtype=int)
capture = cv2.VideoCapture(str(video_path))
frames: list[np.ndarray] = []
try:
for frame_idx in frame_indices:
timestamp = from_timestamp + frame_idx / fps
capture.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
ret, frame_bgr = capture.read()
if not ret:
logging.warning("Could not read frame %d at %.3fs", frame_idx, timestamp)
continue
frames.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
finally:
capture.release()
if not frames:
raise RuntimeError(f"No frames could be sampled from {video_path}")
return np.stack(frames), frame_indices[: len(frames)]
def predict_robometer_progress(
frames: np.ndarray,
*,
task: str,
reward_model_path: str,
device: str,
) -> list[float]:
"""Run Robometer and return per-sampled-frame progress predictions."""
config = RobometerConfig(pretrained_path=reward_model_path, device=device, max_frames=None)
model = RobometerRewardModel.from_pretrained(reward_model_path, config=config)
encoder = RobometerEncoderProcessorStep(
base_model_id=model.config.base_model_id,
use_multi_image=model.config.use_multi_image,
use_per_frame_progress_token=model.config.use_per_frame_progress_token,
max_frames=None,
)
batch = encoder.encode_samples([(frames, task)])
model_device = next(model.model.parameters()).device
inputs = {key: value.to(model_device) if hasattr(value, "to") else value for key, value in batch.items()}
model.eval()
with torch.no_grad():
progress_logits, success_logits = model._compute_rbm_logits(inputs)
decoded = decode_progress_outputs(
progress_logits,
success_logits,
is_discrete_mode=model.config.use_discrete_progress,
)
return decoded["progress_pred"][0]
def process_dataset(
repo_id: str,
episode: int,
reward_model_path: str,
device: str,
camera_key: str | None,
output_dir: Path,
num_frames: int,
task: str | None = None,
create_gif: bool = False,
) -> Path:
safe_name = repo_id.replace("/", "_")
logging.info("Processing %s episode %d with Robometer %s", repo_id, episode, reward_model_path)
local_path = download_episode_metadata(repo_id, episode)
episode_meta = load_episode_meta(local_path, episode, camera_key)
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
task_name = task or episode_meta.get("task_name", "")
if not task_name:
raise ValueError("No task found in dataset metadata. Pass --task explicitly.")
frames, frame_indices = sample_episode_frames(
video_path,
from_timestamp=episode_meta["from_ts"],
to_timestamp=episode_meta["to_ts"],
fps=episode_meta["fps"],
num_frames=num_frames,
)
logging.info("Sampled %d frames for Robometer inference", len(frames))
progress = predict_robometer_progress(
frames,
task=task_name,
reward_model_path=reward_model_path,
device=device,
)
progress_data = np.stack([frame_indices, np.asarray(progress, dtype=np.float32)], axis=1)
logging.info("Progress predictions: %s", [round(float(value), 3) for value in progress])
output_path = output_dir / f"{safe_name}_ep{episode}_robometer_progress.mp4"
final_path = composite_progress_video(
video_path=video_path,
from_timestamp=episode_meta["from_ts"],
to_timestamp=episode_meta["to_ts"],
progress_data=progress_data,
output_path=output_path,
fps=episode_meta["fps"],
task_name=task_name,
)
if create_gif:
final_path = convert_mp4_to_gif(final_path)
return final_path
def main() -> None:
parser = argparse.ArgumentParser(
description="Create MP4/GIF videos with Robometer progress overlay for dataset episodes."
)
parser.add_argument("--repo-id", required=True, help="Hugging Face LeRobot dataset repo id.")
parser.add_argument("--episode", type=int, required=True, help="Episode index to visualize.")
parser.add_argument(
"--reward-model-path",
default="lilkm/robometer-4b",
help="Robometer checkpoint path or Hub repo id (e.g. lilkm/robometer-4b).",
)
parser.add_argument("--device", default=_default_device(), help="Torch device for Robometer inference.")
parser.add_argument(
"--camera-key",
default=None,
help="Camera observation key (e.g. observation.images.top). Auto-selects first camera if omitted.",
)
parser.add_argument(
"--task", default=None, help="Task description override if dataset metadata lacks one."
)
parser.add_argument(
"--num-frames",
type=int,
default=8,
help="Number of episode frames to sample for Robometer inference.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("progress_videos"),
help="Directory to write output files.",
)
parser.add_argument("--gif", action="store_true", help="Also generate a GIF from the MP4 output.")
args = parser.parse_args()
init_logging()
args.output_dir.mkdir(parents=True, exist_ok=True)
result = process_dataset(
repo_id=args.repo_id,
episode=args.episode,
reward_model_path=args.reward_model_path,
device=args.device,
camera_key=args.camera_key,
output_dir=args.output_dir,
num_frames=args.num_frames,
task=args.task,
create_gif=args.gif,
)
logging.info("Output: %s", result)
if __name__ == "__main__":
main()

View File

@@ -80,7 +80,7 @@
"}\n",
"\n",
"# Dataset\n",
"HF_USER = \"your_hf_username\" # `hf auth whoami` to find your username\n",
"HF_USER = \"your_hf_username\" # `huggingface-cli whoami` to find your username\n",
"DATASET_NAME = \"my_so101_dataset\"\n",
"TASK_DESCRIPTION = \"pick and place the block\"\n",
"NUM_EPISODES = 10\n",
@@ -291,34 +291,7 @@
"\n",
"Uses `POLICY_PATH` from the Configuration cell (defaults to the Hub repo ID). You can also put there the `LAST_CHECKPOINT_PATH`.\n",
"\n",
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details.\n",
"\n",
"Recently ```lerobot-rollout``` was introduced, you can [read more about it here](https://huggingface.co/docs/lerobot/main/en/il_robots?eval=Base+mode+%28no+recording%29#run-inference-and-evaluate-your-policy)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
" \"lerobot-rollout\",\n",
" \"--strategy.type=base\",\n",
" f\"--policy.path={POLICY_PATH}\",\n",
" f\"--robot.type={ROBOT_TYPE}\",\n",
" f\"--robot.port={ROBOT_PORT}\",\n",
" CAMERAS_FLAG,\n",
" f'--task=\"{TASK_DESCRIPTION}\"',\n",
" \"--duration=60\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"if you are using the V0.5.1 release you should use ```lerobot-record``` instead of rollout"
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details."
]
},
{

View File

@@ -95,7 +95,7 @@ dependencies = [
# ── Feature-scoped extras ──────────────────────────────────
dataset = [
"datasets>=4.7.0,<5.0.0",
"datasets>=4.0.0,<5.0.0",
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]",
@@ -138,9 +138,7 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
# Common
av-dep = ["av>=15.0.0,<16.0.0"]
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
placo-dep = ["placo>=0.9.6,<0.9.16"]
placo-dep = ["placo>=0.9.6,<0.9.17"]
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
@@ -153,8 +151,6 @@ pyserial-dep = ["pyserial>=3.5,<4.0"]
deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"]
pynput-dep = ["pynput>=1.7.8,<1.9.0"]
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
motorbridge-dep = ["motorbridge>=0.3.2,<0.4.0"]
motorbridge-smart-servo-dep = ["motorbridge-smart-servo>=0.0.4,<0.1.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
@@ -178,9 +174,6 @@ unitree_g1 = [
"lerobot[pygame-dep]",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
# leader (motorbridge-smart-servo / FashionStar UART servos).
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
kinematics = ["lerobot[placo-dep]"]
intelrealsense = [
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
@@ -198,7 +191,6 @@ wallx = [
"lerobot[qwen-vl-utils-dep]",
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
groot = [
@@ -212,11 +204,10 @@ groot = [
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
topreward = ["lerobot[transformers-dep]"]
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
@@ -270,19 +261,16 @@ all = [
"lerobot[lekiwi]",
"lerobot[openarms]",
"lerobot[reachy2]",
"lerobot[rebot]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
"lerobot[diffusion]",
"lerobot[multi_task_dit]",
"lerobot[wallx]",
"lerobot[pi]",
"lerobot[molmoact2]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[vla_jepa]",
"lerobot[async]",
"lerobot[dev]",
"lerobot[test]",
@@ -293,7 +281,6 @@ all = [
"lerobot[libero]; sys_platform == 'linux'",
"lerobot[metaworld]",
"lerobot[sarm]",
"lerobot[topreward]",
"lerobot[peft]",
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
]
@@ -316,6 +303,7 @@ lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
lerobot-export-robometer="lerobot.scripts.lerobot_export_robometer:main"
# ---------------- Tool Configurations ----------------
@@ -409,11 +397,8 @@ default.extend-ignore-identifiers-re = [
"ein",
"thw",
"inpt",
"arange",
"is_compileable",
"ROBOTIS",
"OT_VALUE",
"VanderBilt"
"OT_VALUE"
]
# TODO: Uncomment when ready to use

View File

@@ -0,0 +1,164 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
"""Pinpoint exactly which rows of ``embed_tokens`` / ``lm_head`` differ.
Useful follow-up to ``scripts/verify_robometer_export.py`` when the verifier
reports a small tail of differing keys but you want to know whether the
diff is:
1. Concentrated in the 5 special-token rows added by ``resize_token_embeddings``
(expected non-determinism: mean-resize sampling differs between runs).
2. Spread across the full vocabulary (would point to a real loading bug).
Also confirms whether ``apply_upstream_checkpoint`` actually overwrites the
embed/lm-head tensors when loading the upstream state dict (vs. silently
skipping them due to a key mismatch).
"""
from __future__ import annotations
import argparse
import sys
import torch
from safetensors.torch import load_file
from lerobot.configs.rewards import RewardModelConfig
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
from lerobot.rewards.robometer._upstream_loader import (
_download_robometer_snapshot,
_remap_state_dict_keys,
_resolve_checkpoint_safetensors_files,
apply_upstream_checkpoint,
)
EMBED_KEY = "model.model.language_model.embed_tokens.weight"
LMHEAD_KEY = "model.lm_head.weight"
def _load_upstream(path: str) -> RobometerRewardModel:
cfg = RobometerConfig(pretrained_path=path, device="cpu")
model = RobometerRewardModel(cfg)
apply_upstream_checkpoint(model, path)
model.eval()
return model
def _load_lerobot(path: str) -> RobometerRewardModel:
cfg = RewardModelConfig.from_pretrained(path)
if not isinstance(cfg, RobometerConfig):
raise TypeError(f"Expected RobometerConfig, got {type(cfg)}")
cfg.pretrained_path = path
cfg.device = "cpu"
return RobometerRewardModel.from_pretrained(path, config=cfg)
def _inspect_upstream_state_dict(upstream_path: str, model: RobometerRewardModel) -> None:
"""Dump the upstream state-dict view of the embed/lm-head tensors.
Loads the raw upstream safetensors (pre-remap), runs the remapper, and
reports whether the embed/lm-head keys survive into the merged dict that
eventually hits ``model.load_state_dict``.
"""
snapshot_dir = _download_robometer_snapshot(upstream_path)
files = _resolve_checkpoint_safetensors_files(snapshot_dir)
merged: dict[str, torch.Tensor] = {}
for path in files:
merged.update(load_file(str(path)))
remapped = _remap_state_dict_keys(merged, model)
print(f"\n=== Upstream state-dict inspection (snapshot at {snapshot_dir}) ===")
print(f"raw keys (before remap) : {len(merged)}")
print(f"keys after remap : {len(remapped)}")
print(f"model expects (state_dict): {len(model.state_dict())}")
expected = set(model.state_dict())
present_after_remap = set(remapped) & expected
print(f"keys present after remap : {len(present_after_remap)}")
missing_keys = expected - set(remapped)
print(f"keys missing from remap : {len(missing_keys)}")
if missing_keys:
sample = list(missing_keys)[:10]
print(f" sample missing keys : {sample}")
unexpected_keys = set(remapped) - expected
print(f"keys unexpected by model : {len(unexpected_keys)}")
if unexpected_keys:
sample = list(unexpected_keys)[:10]
print(f" sample unexpected keys : {sample}")
for key in (EMBED_KEY, LMHEAD_KEY):
present = key in remapped
shape = tuple(remapped[key].shape) if present else None
print(f" {key:60s} present={present}, shape={shape}")
def _diff_embed(name: str, a: torch.Tensor, b: torch.Tensor, special_token_count: int) -> None:
a = a.float()
b = b.float()
if a.shape != b.shape:
print(f"{name} shape mismatch: {tuple(a.shape)} vs {tuple(b.shape)}")
return
abs_diff = (a - b).abs()
per_row_max = abs_diff.max(dim=1).values
nz_rows = (per_row_max > 0).nonzero(as_tuple=True)[0].tolist()
print(f"\n=== {name} (shape {tuple(a.shape)}) ===")
print(f"global max|Δ| = {abs_diff.max().item():.3e}")
print(f"rows with any diff = {len(nz_rows)}")
if nz_rows:
first = nz_rows[:10]
last = nz_rows[-10:]
print(f" first nonzero rows = {first}")
print(f" last nonzero rows = {last}")
vocab_size = a.shape[0]
base_vocab = vocab_size - special_token_count
special_rows = list(range(base_vocab, vocab_size))
in_special = [r for r in nz_rows if r in special_rows]
out_special = [r for r in nz_rows if r not in special_rows]
print(
f" diffs in special-token rows ({base_vocab}..{vocab_size - 1}): {len(in_special)}/{special_token_count}"
)
print(f" diffs in base-vocab rows (0..{base_vocab - 1}) : {len(out_special)}")
for r in special_rows:
print(
f" row {r}: max|Δ|={per_row_max[r].item():.3e}, "
f"upstream_norm={a[r].norm().item():.3e}, lerobot_norm={b[r].norm().item():.3e}"
)
def main() -> int:
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("--upstream", required=True)
parser.add_argument("--lerobot", required=True)
parser.add_argument(
"--special-token-count",
type=int,
default=5,
help="Number of special tokens Robometer adds. Defaults to len(ROBOMETER_SPECIAL_TOKENS)=5.",
)
args = parser.parse_args()
print(f"Loading upstream: {args.upstream}")
upstream = _load_upstream(args.upstream)
print(f"Loading LeRobot-format: {args.lerobot}")
lerobot = _load_lerobot(args.lerobot)
_inspect_upstream_state_dict(args.upstream, upstream)
sd_u, sd_l = upstream.state_dict(), lerobot.state_dict()
for key in (EMBED_KEY, LMHEAD_KEY):
if key not in sd_u or key not in sd_l:
print(f"❌ key missing: {key} (upstream={key in sd_u}, lerobot={key in sd_l})")
continue
_diff_embed(key, sd_u[key], sd_l[key], args.special_token_count)
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,168 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
"""Extract one LIBERO episode for Robometer parity testing.
Loads a LeRobot LIBERO (or any video-bearing LeRobot) dataset, picks one
episode, samples ``--num-frames`` frames uniformly across its duration
(matching upstream Robometer's default of 8 frames), and saves them to
``.npz`` plus a sidecar ``.txt`` task file.
The ``.npz`` layout (``frames`` key, ``(T, H, W, C) uint8``) is what upstream
``example_inference_local.py`` consumes, so the same file feeds both pipelines
and frame sampling cannot drift.
Workflow:
1. Run this script (LeRobot env) to produce ``frames.npz`` + ``task.txt``.
2. Pass them to upstream ``scripts/example_inference_local.py``
(upstream env) to produce reference progress / success outputs.
3. Pass the same ``frames.npz`` to ``scripts/parity_robometer.py``
(LeRobot env) to compare both sides.
Example:
uv run python scripts/extract_libero_episode_for_parity.py \\
--repo-id lerobot/libero_10_image \\
--episode 0 \\
--num-frames 8 \\
--out-dir /tmp/libero_ep0
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import numpy as np
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def _pick_visual_feature(features: dict, requested: str | None) -> str:
"""Return a visual feature key, preferring ``requested`` when given."""
visual_keys = [
key
for key, ft in features.items()
if getattr(ft, "type", None) == FeatureType.VISUAL or ft.get("dtype", "") == "video"
]
if not visual_keys:
raise ValueError(f"Dataset has no visual feature; available: {list(features)}")
if requested is not None:
if requested not in visual_keys:
raise ValueError(f"Camera key {requested!r} not in dataset visual features {visual_keys}")
return requested
return visual_keys[0]
def _frame_uint8_hwc(tensor: torch.Tensor) -> np.ndarray:
"""Convert a LeRobotDataset video frame to ``uint8`` ``(H, W, C)`` RGB."""
arr = tensor.detach().cpu().numpy()
if arr.ndim == 3 and arr.shape[0] in (1, 3):
arr = arr.transpose(1, 2, 0)
if arr.dtype != np.uint8:
arr = np.clip(arr * 255.0 if arr.max() <= 1.0 + 1e-3 else arr, 0, 255).astype(np.uint8)
if arr.shape[-1] == 1:
arr = np.repeat(arr, 3, axis=-1)
return arr
def main() -> int:
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--repo-id",
default="lerobot/libero_10_image",
help="LeRobot LIBERO (or other) dataset repo id (default: lerobot/libero_10_image).",
)
parser.add_argument("--episode", type=int, default=0, help="Episode index.")
parser.add_argument(
"--camera-key",
default=None,
help="Visual feature key (e.g. observation.images.image). Auto-selects first if omitted.",
)
parser.add_argument(
"--num-frames",
type=int,
default=8,
help="Number of frames to sample uniformly (default: 8 — Robometer's training-time default).",
)
parser.add_argument(
"--out-dir",
type=Path,
default=Path("outputs/robometer_parity/libero"),
help="Directory to write frames.npz / task.txt / frame_indices.npy.",
)
args = parser.parse_args()
print(f"Loading {args.repo_id} (episode {args.episode})...")
dataset = LeRobotDataset(args.repo_id, episodes=[args.episode])
camera_key = _pick_visual_feature(dataset.features, args.camera_key)
print(f"Using camera key: {camera_key}")
ep_from = int(dataset.episode_data_index["from"][0].item())
ep_to = int(dataset.episode_data_index["to"][0].item())
total_frames = ep_to - ep_from
if total_frames <= 0:
print(f"ERROR: episode {args.episode} has no frames.", file=sys.stderr)
return 1
print(f"Episode has {total_frames} frames; sampling {args.num_frames} uniformly.")
indices = np.linspace(0, total_frames - 1, num=min(args.num_frames, total_frames), dtype=int)
frames: list[np.ndarray] = []
task: str = ""
for offset in indices:
sample = dataset[ep_from + int(offset)]
frame_tensor = sample[camera_key]
frames.append(_frame_uint8_hwc(frame_tensor))
if not task:
task = sample.get("task", "") or ""
if not task:
print("ERROR: episode has no task description in metadata.", file=sys.stderr)
return 1
frames_array = np.stack(frames)
args.out_dir.mkdir(parents=True, exist_ok=True)
frames_path = args.out_dir / "frames.npz"
task_path = args.out_dir / "task.txt"
indices_path = args.out_dir / "frame_indices.npy"
np.savez(frames_path, frames=frames_array)
task_path.write_text(task + "\n", encoding="utf-8")
np.save(indices_path, indices)
print()
print(f"Wrote {frames_path} (shape={frames_array.shape}, dtype={frames_array.dtype})")
print(f"Wrote {task_path} (task={task!r})")
print(f"Wrote {indices_path} (frame_indices={indices.tolist()})")
print()
print("Next steps:")
print(" # in upstream env (where `robometer` is importable):")
print(
f" python third_party/robometer/scripts/example_inference_local.py \\\n"
f" --model-path robometer/Robometer-4B \\\n"
f" --video {frames_path} \\\n"
f' --task "{task}" \\\n'
f" --out {args.out_dir / 'upstream.npy'}"
)
print()
print(" # back in LeRobot env:")
print(
f" uv run python scripts/parity_robometer.py \\\n"
f" --frames {frames_path} \\\n"
f' --task "{task}" \\\n'
f" --upstream-progress {args.out_dir / 'upstream.npy'} \\\n"
f" --upstream-success {args.out_dir / 'upstream_success_probs.npy'}"
)
return 0
if __name__ == "__main__":
sys.exit(main())

232
scripts/parity_robometer.py Normal file
View File

@@ -0,0 +1,232 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
"""Functional parity check: LeRobot Robometer vs. upstream Robometer.
Runs the in-tree :class:`RobometerRewardModel` on the same frames + task that
upstream Robometer was run on, and compares per-frame progress / success
predictions against reference outputs saved by upstream's
``scripts/example_inference_local.py``.
Workflow:
1. In the upstream Robometer environment (where ``robometer`` is importable),
run::
python third_party/robometer/scripts/example_inference_local.py \\
--model-path robometer/Robometer-4B \\
--video /path/to/episode.mp4 \\
--task "Open the drawer" \\
--fps 1.0 \\
--out /tmp/robometer_upstream.npy
This produces:
- ``/tmp/robometer_upstream.npy`` (progress predictions)
- ``/tmp/robometer_upstream_success_probs.npy`` (success probabilities)
2. Extract the exact same frames the upstream script used, save as ``.npz``::
# quick helper: extract frames at the same fps and save as .npz
python -c "
from third_party.robometer.scripts.example_inference_local import load_frames_input
import numpy as np
frames = load_frames_input('/path/to/episode.mp4', fps=1.0, max_frames=512)
np.savez('/tmp/robometer_frames.npz', frames=frames)
"
3. In this LeRobot env, run this script::
uv run python scripts/parity_robometer.py \\
--frames /tmp/robometer_frames.npz \\
--task "Open the drawer" \\
--upstream-progress /tmp/robometer_upstream.npy \\
--upstream-success /tmp/robometer_upstream_success_probs.npy \\
--lerobot-model lilkm/robometer-4b
"""
from __future__ import annotations
import argparse
import sys
import numpy as np
import torch
from lerobot.configs.rewards import RewardModelConfig
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
from lerobot.rewards.robometer.modeling_robometer import decode_progress_outputs
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
def _load_frames(path: str) -> np.ndarray:
"""Load frames from .npy/.npz. Expects (T, H, W, C) uint8."""
if path.endswith(".npy"):
frames = np.load(path)
elif path.endswith(".npz"):
with np.load(path, allow_pickle=False) as npz:
frames = npz["frames"].copy() if "frames" in npz else next(iter(npz.values())).copy()
else:
raise ValueError(f"Frames must be .npy or .npz (got {path!r}).")
if frames.dtype != np.uint8:
frames = np.clip(frames, 0, 255).astype(np.uint8)
if frames.ndim != 4:
raise ValueError(f"Frames must be 4D (T,H,W,C); got shape {frames.shape}.")
if frames.shape[-1] not in (1, 3):
# Probably (T,C,H,W) — transpose
if frames.shape[1] in (1, 3):
frames = frames.transpose(0, 2, 3, 1)
else:
raise ValueError(f"Cannot interpret frame channel layout: {frames.shape}.")
return frames
def _run_lerobot(
frames: np.ndarray,
task: str,
model_path: str,
device: str,
) -> tuple[np.ndarray, np.ndarray]:
"""Run LeRobot's Robometer on the given frames; return (progress, success)."""
cfg = RobometerConfig(pretrained_path=model_path, device=device, max_frames=None)
model = RobometerRewardModel.from_pretrained(model_path, config=cfg)
encoder = RobometerEncoderProcessorStep(
base_model_id=model.config.base_model_id,
use_multi_image=model.config.use_multi_image,
use_per_frame_progress_token=model.config.use_per_frame_progress_token,
max_frames=None,
)
batch = encoder.encode_samples([(frames, task)])
model_device = next(model.model.parameters()).device
inputs = {key: value.to(model_device) if hasattr(value, "to") else value for key, value in batch.items()}
model.eval()
with torch.no_grad():
progress_logits, success_logits = model._compute_rbm_logits(inputs)
decoded = decode_progress_outputs(
progress_logits,
success_logits,
is_discrete_mode=model.config.use_discrete_progress,
)
progress = np.asarray(decoded["progress_pred"][0], dtype=np.float32)
success = (
np.asarray(decoded["success_probs"][0], dtype=np.float32)
if decoded["success_probs"]
else np.array([], dtype=np.float32)
)
return progress, success
def _compare(name: str, lerobot: np.ndarray, upstream: np.ndarray, atol: float, rtol: float) -> bool:
print(f"\n=== {name} ===")
if lerobot.shape != upstream.shape:
print(f"shape mismatch: lerobot={lerobot.shape} upstream={upstream.shape}")
return False
abs_diff = np.abs(lerobot - upstream)
rel_diff = abs_diff / (np.abs(upstream) + 1e-12)
print(f"shape : {lerobot.shape}")
print(f"max |Δ| : {abs_diff.max():.3e}")
print(f"mean |Δ| : {abs_diff.mean():.3e}")
print(f"max rel |Δ| : {rel_diff.max():.3e}")
print(f"lerobot[:5] : {lerobot[:5]}")
print(f"upstream[:5] : {upstream[:5]}")
within_tol = bool(np.allclose(lerobot, upstream, atol=atol, rtol=rtol))
print(f"allclose(atol={atol}, rtol={rtol}) -> {within_tol}")
return within_tol
def main() -> int:
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--frames",
required=True,
help=".npy / .npz file with the exact frames upstream was run on (T,H,W,C uint8).",
)
parser.add_argument("--task", required=True, help="Task instruction string.")
parser.add_argument(
"--upstream-progress",
required=True,
help="Reference progress .npy saved by upstream example_inference_local.py.",
)
parser.add_argument(
"--upstream-success",
default=None,
help="Optional reference success_probs .npy. If omitted, success comparison is skipped.",
)
parser.add_argument(
"--lerobot-model",
default="lilkm/robometer-4b",
help="LeRobot-format Robometer Hub repo id or local path.",
)
parser.add_argument(
"--device",
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device for the LeRobot model (default: cuda if available).",
)
parser.add_argument(
"--atol",
type=float,
default=1e-3,
help="Absolute tolerance for allclose (default: 1e-3; bf16 round-trip headroom).",
)
parser.add_argument(
"--rtol",
type=float,
default=1e-2,
help="Relative tolerance for allclose (default: 1e-2).",
)
parser.add_argument(
"--out-prefix",
default="lerobot_robometer_outputs",
help="Save the LeRobot outputs as <prefix>_progress.npy / <prefix>_success.npy.",
)
args = parser.parse_args()
# 0. Sanity: confirm the LeRobot config is a RobometerConfig.
cfg = RewardModelConfig.from_pretrained(args.lerobot_model)
if not isinstance(cfg, RobometerConfig):
print(f"ERROR: {args.lerobot_model!r} does not resolve to a RobometerConfig.", file=sys.stderr)
return 2
# 1. Load frames + task + upstream reference outputs.
frames = _load_frames(args.frames)
upstream_progress = np.load(args.upstream_progress).astype(np.float32)
upstream_success = (
np.load(args.upstream_success).astype(np.float32) if args.upstream_success is not None else None
)
print(f"Loaded {frames.shape[0]} frames at {frames.shape[1:]}, task={args.task!r}")
print(f"LeRobot model: {args.lerobot_model} device: {args.device}")
# 2. Run LeRobot pipeline.
progress, success = _run_lerobot(frames, args.task, args.lerobot_model, args.device)
np.save(f"{args.out_prefix}_progress.npy", progress)
if success.size > 0:
np.save(f"{args.out_prefix}_success.npy", success)
print(f"Saved LeRobot outputs to {args.out_prefix}_progress.npy / _success.npy")
# 3. Compare to upstream references.
progress_ok = _compare("progress", progress, upstream_progress, args.atol, args.rtol)
if upstream_success is not None and success.size > 0:
success_ok = _compare("success_probs", success, upstream_success, args.atol, args.rtol)
else:
success_ok = True
print("\n(skipping success comparison — upstream success file not provided)")
print()
if progress_ok and success_ok:
print("Parity check passed.")
return 0
print("Parity check FAILED.")
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,362 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
"""Run LeRobot Robometer parity against upstream Robometer's bundled examples.
Upstream Robometer ships three reference videos with their pre-computed
progress / success outputs at
``third_party/robometer/scripts/example_videos/``::
soar_put_green_stick_in_brown_bowl.mp4
+ soar_put_green_stick_in_brown_bowl_rewards.npy (progress)
+ soar_put_green_stick_in_brown_bowl_rewards_success_probs.npy (success)
berkeley_rpt_stack_cup.mp4
+ berkeley_rpt_stack_cup_rewards.npy
+ berkeley_rpt_stack_cup_rewards_success_probs.npy
jaco_play_pick_up_green_cup.mp4
+ pick_up_green_cup_rewards.npy
+ pick_up_green_cup_rewards_success_probs.npy
This script:
1. Decodes each video at upstream's sampling fps using ``av`` (PyAV), with the
same linspace-over-total-frames logic as upstream's ``extract_frames``.
2. Runs the LeRobot ``RobometerRewardModel`` on those frames + the task from
upstream's README.
3. Compares per-frame progress / success to the pre-saved upstream outputs.
This means you do **not** need to install upstream Robometer to confirm parity.
Run::
uv run python scripts/parity_robometer_upstream_examples.py \\
--lerobot-model lilkm/robometer-4b \\
--device cuda \\
--decoder decord
The number of frames sampled per video is derived from the length of each
upstream ``.npy`` reference, so the script does not need a ``--fps`` argument
(the README documents ``fps=3`` for SOAR / Berkeley, but the Jaco Play
reference was generated with a different fps).
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import numpy as np
import torch
from lerobot.configs.rewards import RewardModelConfig
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
from lerobot.rewards.robometer.modeling_robometer import decode_progress_outputs
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
try:
import decord # type: ignore
_HAS_DECORD = True
except ImportError:
decord = None # type: ignore
_HAS_DECORD = False
try:
import av
_HAS_AV = True
except ImportError:
av = None # type: ignore
_HAS_AV = False
EXAMPLES = [
{
"name": "soar_put_green_stick_in_brown_bowl",
"video": "soar_put_green_stick_in_brown_bowl.mp4",
"task": "Put green stick in brown bowl",
"progress_npy": "soar_put_green_stick_in_brown_bowl_rewards.npy",
"success_npy": "soar_put_green_stick_in_brown_bowl_rewards_success_probs.npy",
},
{
"name": "berkeley_rpt_stack_cup",
"video": "berkeley_rpt_stack_cup.mp4",
"task": "Pick up the yellow cup and stack it on the other cup",
"progress_npy": "berkeley_rpt_stack_cup_rewards.npy",
"success_npy": "berkeley_rpt_stack_cup_rewards_success_probs.npy",
},
{
"name": "jaco_play_pick_up_green_cup",
"video": "jaco_play_pick_up_green_cup.mp4",
"task": "Pick up the green cup",
"progress_npy": "pick_up_green_cup_rewards.npy",
"success_npy": "pick_up_green_cup_rewards_success_probs.npy",
},
]
def _extract_frames_decord(video_path: Path, num_frames: int) -> tuple[np.ndarray, str]:
"""Sample ``num_frames`` indices uniformly from the video using decord.
Mirrors upstream's ``extract_frames`` indexing
(``third_party/robometer/scripts/example_inference.py``): a
``np.linspace(0, total_frames-1, num_frames)`` lookup over decord's
``VideoReader``. We pass ``num_frames`` explicitly (derived from the
upstream reference output length) so we don't have to guess what ``fps``
upstream actually used when generating each saved ``.npy`` — the file
length is the ground truth.
"""
vr = decord.VideoReader(str(video_path), num_threads=1)
total_frames = len(vr)
if total_frames == 0:
raise RuntimeError(f"No decodable frames in {video_path}.")
desired_frames = max(1, min(int(num_frames), total_frames))
indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
frames = vr.get_batch(indices).asnumpy()
native_fps = float(vr.get_avg_fps()) or 1.0
return frames, f"decord total={total_frames} native_fps={native_fps:.3f}"
def _extract_frames_av(video_path: Path, num_frames: int) -> tuple[np.ndarray, str]:
"""PyAV fallback for environments without decord.
PyAV and decord can disagree on ``total_frames`` for the same container,
so the sampled frame indices can drift. Install ``decord`` for a real
parity check; this fallback is for smoke tests only.
"""
container = av.open(str(video_path))
stream = container.streams.video[0]
native_fps = float(stream.average_rate) if stream.average_rate else float(stream.guessed_rate or 30.0)
rgb_frames: list[np.ndarray] = []
for frame in container.decode(stream):
rgb_frames.append(frame.to_ndarray(format="rgb24"))
container.close()
total_frames = len(rgb_frames)
if total_frames == 0:
raise RuntimeError(f"No decodable frames in {video_path}.")
desired_frames = max(1, min(int(num_frames), total_frames))
indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int)
frames = np.stack([rgb_frames[i] for i in indices])
return frames, f"av total={total_frames} native_fps={native_fps:.3f}"
def _extract_frames(video_path: Path, num_frames: int, prefer: str) -> tuple[np.ndarray, str]:
"""Decoder dispatch. ``prefer`` is ``"decord"`` | ``"av"`` | ``"auto"``."""
if prefer == "decord":
if not _HAS_DECORD:
raise RuntimeError("decord requested but not installed (`uv pip install decord`).")
return _extract_frames_decord(video_path, num_frames)
if prefer == "av":
if not _HAS_AV:
raise RuntimeError("av requested but not installed.")
return _extract_frames_av(video_path, num_frames)
# auto
if _HAS_DECORD:
return _extract_frames_decord(video_path, num_frames)
if _HAS_AV:
return _extract_frames_av(video_path, num_frames)
raise RuntimeError("No video decoder available (install `decord` or `av`).")
def _pearson(a: np.ndarray, b: np.ndarray) -> float:
"""Pearson correlation; returns 1.0 for constant inputs (no signal to align)."""
a = a.astype(np.float64)
b = b.astype(np.float64)
if a.size < 2:
return 1.0
da = a - a.mean()
db = b - b.mean()
denom = float(np.sqrt((da * da).sum()) * np.sqrt((db * db).sum()))
if denom == 0:
return 1.0
return float((da * db).sum() / denom)
def _run_lerobot(
model: RobometerRewardModel,
encoder: RobometerEncoderProcessorStep,
frames: np.ndarray,
task: str,
) -> tuple[np.ndarray, np.ndarray]:
batch = encoder.encode_samples([(frames, task)])
device = next(model.model.parameters()).device
inputs = {key: value.to(device) if hasattr(value, "to") else value for key, value in batch.items()}
model.eval()
with torch.no_grad():
progress_logits, success_logits = model._compute_rbm_logits(inputs)
decoded = decode_progress_outputs(
progress_logits, success_logits, is_discrete_mode=model.config.use_discrete_progress
)
progress = np.asarray(decoded["progress_pred"][0], dtype=np.float32)
success = (
np.asarray(decoded["success_probs"][0], dtype=np.float32)
if decoded["success_probs"]
else np.array([], dtype=np.float32)
)
return progress, success
def _compare(
name: str,
lerobot: np.ndarray,
upstream: np.ndarray,
*,
atol: float,
pearson_min: float,
) -> bool:
if lerobot.shape != upstream.shape:
print(f" {name:8s} SHAPE MISMATCH lerobot={lerobot.shape} upstream={upstream.shape}")
return False
abs_diff = np.abs(lerobot - upstream)
pearson = _pearson(lerobot, upstream)
abs_ok = bool(abs_diff.max() <= atol)
pearson_ok = bool(pearson >= pearson_min)
verdict = "PASS" if (abs_ok or pearson_ok) else "FAIL"
print(
f" {name:8s} shape={lerobot.shape} max|Δ|={abs_diff.max():.3e} "
f"mean|Δ|={abs_diff.mean():.3e} pearson={pearson:.4f} "
f"(atol={atol:.0e} pearson_min={pearson_min:.3f}) -> {verdict}"
)
return abs_ok or pearson_ok
def main() -> int:
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--examples-dir",
type=Path,
default=Path("third_party/robometer/scripts/example_videos"),
help="Directory containing the upstream Robometer example mp4s + .npy outputs.",
)
parser.add_argument(
"--lerobot-model",
default="lilkm/robometer-4b",
help="LeRobot-format Robometer Hub repo id or local path.",
)
parser.add_argument(
"--device",
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device for the LeRobot model.",
)
parser.add_argument(
"--decoder",
choices=("auto", "decord", "av"),
default="auto",
help=(
"Video decoder. ``auto`` prefers decord (matches upstream) and falls back to av. "
"Force ``decord`` for a clean parity check."
),
)
parser.add_argument(
"--progress-atol",
type=float,
default=1e-2,
help="Absolute tolerance for the progress array. Default 1e-2 covers CUDA bf16 noise.",
)
parser.add_argument(
"--success-atol",
type=float,
default=1e-1,
help=(
"Absolute tolerance for the success array. Looser than progress because "
"``sigmoid`` amplifies logit-space noise near 0.5."
),
)
parser.add_argument(
"--pearson-min",
type=float,
default=0.99,
help="Minimum Pearson correlation for a PASS verdict (per array).",
)
args = parser.parse_args()
if args.decoder == "av" or (args.decoder == "auto" and not _HAS_DECORD):
print(
"WARNING: using PyAV decoder. PyAV's total-frame count can differ from decord's, "
"which propagates into different sampled-frame indices. Install `decord` and "
"re-run for a clean parity check.",
file=sys.stderr,
)
examples_dir = args.examples_dir.resolve()
if not examples_dir.is_dir():
print(f"ERROR: examples dir {examples_dir} does not exist.", file=sys.stderr)
return 2
# Sanity-check the LeRobot config is a RobometerConfig before loading weights.
cfg = RewardModelConfig.from_pretrained(args.lerobot_model)
if not isinstance(cfg, RobometerConfig):
print(f"ERROR: {args.lerobot_model!r} did not resolve to a RobometerConfig.", file=sys.stderr)
return 2
print(f"Loading LeRobot Robometer from {args.lerobot_model} on {args.device}...")
cfg.pretrained_path = args.lerobot_model
cfg.device = args.device
model = RobometerRewardModel.from_pretrained(args.lerobot_model, config=cfg)
encoder = RobometerEncoderProcessorStep(
base_model_id=model.config.base_model_id,
use_multi_image=model.config.use_multi_image,
use_per_frame_progress_token=model.config.use_per_frame_progress_token,
max_frames=None,
)
all_ok = True
for ex in EXAMPLES:
video_path = examples_dir / ex["video"]
upstream_progress_path = examples_dir / ex["progress_npy"]
upstream_success_path = examples_dir / ex["success_npy"]
missing = [p for p in (video_path, upstream_progress_path, upstream_success_path) if not p.exists()]
if missing:
print(f"[skip] {ex['name']}: missing {[str(m) for m in missing]}")
all_ok = False
continue
print(f"\n=== {ex['name']} ===")
print(f" task: {ex['task']!r}")
# Trust the upstream reference array as the source of truth for how
# many frames to sample. The README documents fps=3 for SOAR/Berkeley
# but Jaco Play was generated with a different fps, so any hardcoded
# ``--fps`` mismatches at least one example. The npy length always
# tells us what upstream actually used.
upstream_progress = np.load(upstream_progress_path).astype(np.float32)
upstream_success = np.load(upstream_success_path).astype(np.float32)
target_num_frames = int(upstream_progress.shape[0])
frames, decoder_info = _extract_frames(video_path, target_num_frames, prefer=args.decoder)
print(
f" decoded {frames.shape[0]} frames (matches upstream npy length); "
f"shape={frames.shape} [{decoder_info}]"
)
progress, success = _run_lerobot(model, encoder, frames, ex["task"])
progress_ok = _compare(
"progress",
progress,
upstream_progress,
atol=args.progress_atol,
pearson_min=args.pearson_min,
)
success_ok = _compare(
"success",
success,
upstream_success,
atol=args.success_atol,
pearson_min=args.pearson_min,
)
verdict = "PASS" if (progress_ok and success_ok) else "FAIL"
print(f" -> {verdict}")
all_ok = all_ok and progress_ok and success_ok
print()
if all_ok:
print("All upstream example parity checks passed.")
return 0
print("Some upstream example parity checks FAILED.")
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,149 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""Verify that a LeRobot-format Robometer is byte-equivalent to its upstream source.
Run this once after publishing a LeRobot-format Robometer to the Hub, before
flipping the default `RobometerConfig.pretrained_path` to it. It loads both
the upstream snapshot and the re-exported copy, compares state dicts, and
prints a clear pass/fail summary.
Example:
python scripts/verify_robometer_export.py \\
--upstream robometer/Robometer-4B \\
--lerobot lerobot/robometer-4b
python scripts/verify_robometer_export.py \\
--upstream robometer/Robometer-4B \\
--lerobot ./robometer-4b-lerobot # local folder also works
"""
from __future__ import annotations
import argparse
import sys
from lerobot.configs.rewards import RewardModelConfig
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
from lerobot.rewards.robometer._upstream_loader import apply_upstream_checkpoint
def _load_upstream(path: str) -> RobometerRewardModel:
# Fresh ``RobometerConfig`` (``vlm_config=None``) triggers
# ``RobometerRewardModel.__init__``'s upstream-matching path: download
# base Qwen, resize for ROBOMETER_SPECIAL_TOKENS. The subsequent
# ``apply_upstream_checkpoint`` call resizes again if the checkpoint's
# vocab differs (e.g. upstream was trained against an older Qwen).
cfg = RobometerConfig(pretrained_path=path, device="cpu")
model = RobometerRewardModel(cfg)
apply_upstream_checkpoint(model, path)
model.eval()
return model
def _load_lerobot(path: str) -> RobometerRewardModel:
cfg = RewardModelConfig.from_pretrained(path)
if not isinstance(cfg, RobometerConfig):
raise TypeError(f"Expected RobometerConfig in LeRobot export, got {type(cfg)}")
cfg.pretrained_path = path
cfg.device = "cpu"
return RobometerRewardModel.from_pretrained(path, config=cfg)
def compare_state_dicts(a: RobometerRewardModel, b: RobometerRewardModel) -> bool:
sd_a, sd_b = a.state_dict(), b.state_dict()
keys_a, keys_b = set(sd_a), set(sd_b)
missing = keys_a - keys_b
extra = keys_b - keys_a
if missing:
print(f"{len(missing)} keys missing in LeRobot-format model (sample: {list(missing)[:5]})")
if extra:
print(f"{len(extra)} extra keys in LeRobot-format model (sample: {list(extra)[:5]})")
if missing or extra:
return False
diff_summary: list[tuple[str, float]] = []
for key in sorted(keys_a):
ta, tb = sd_a[key], sd_b[key]
if ta.shape != tb.shape:
print(f"❌ shape mismatch at {key}: {tuple(ta.shape)} vs {tuple(tb.shape)}")
return False
# Compare in float to avoid bfloat16 equality quirks.
max_abs = (ta.float() - tb.float()).abs().max().item()
if max_abs > 0:
diff_summary.append((key, max_abs))
if not diff_summary:
print(f"✅ All {len(keys_a)} parameters identical")
return True
# Some keys differ; show worst offenders.
diff_summary.sort(key=lambda kv: kv[1], reverse=True)
print(f"⚠️ {len(diff_summary)} keys differ. Top 10 by max abs diff:")
for key, value in diff_summary[:10]:
print(f" {key:60s} max|Δ| = {value:.3e}")
# Tolerance: bf16 round-trips can introduce ULP-level noise but no real
# change. Allow up to 1e-3 absolute difference; anything larger is a real
# divergence.
worst = diff_summary[0][1]
if worst < 1e-3:
print(f"✅ Worst diff {worst:.3e} is within bf16 round-trip tolerance")
return True
print(f"❌ Worst diff {worst:.3e} exceeds tolerance (1e-3)")
return False
def main() -> int:
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("--upstream", required=True, help="Upstream Robometer repo id or local path.")
parser.add_argument("--lerobot", required=True, help="LeRobot-format Robometer repo id or local path.")
args = parser.parse_args()
print(f"Loading upstream: {args.upstream}")
upstream = _load_upstream(args.upstream)
print(f"Loading LeRobot-format: {args.lerobot}")
lerobot = _load_lerobot(args.lerobot)
print("\n=== Config comparison ===")
config_ok = True
for field in [
"base_model_id",
"torch_dtype",
"use_multi_image",
"use_per_frame_progress_token",
"average_temporal_patches",
"frame_pooling",
"frame_pooling_attn_temperature",
"progress_loss_type",
"progress_discrete_bins",
]:
a, b = getattr(upstream.config, field), getattr(lerobot.config, field)
field_ok = a == b
config_ok = config_ok and field_ok
ok = "" if field_ok else ""
print(f" {ok} {field}: upstream={a!r}, lerobot={b!r}")
print("\n=== State-dict comparison ===")
state_dict_ok = compare_state_dicts(upstream, lerobot)
print()
if config_ok and state_dict_ok:
print("🎉 Verification passed — safe to flip the default.")
return 0
print("⛔ Verification failed — DO NOT flip the default.")
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -199,13 +199,12 @@ class OpenCVCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
"""
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
if self.config.fourcc is not None:
self._validate_fourcc()
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
set_fourcc_after_size_and_fps = platform.system() == "Windows"
if self.config.fourcc is not None and not set_fourcc_after_size_and_fps:
self._validate_fourcc()
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
@@ -223,11 +222,6 @@ class OpenCVCamera(Camera):
else:
self._validate_fps()
if self.config.fourcc is not None and set_fourcc_after_size_and_fps:
# On Windows with DSHOW, changing the resolution can silently override the FOURCC setting.
# Set FOURCC last to make sure the requested pixel format is actually enforced.
self._validate_fourcc()
def _validate_fps(self) -> None:
"""Validates and sets the camera's frames per second (FPS)."""

View File

@@ -24,7 +24,6 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
from .dataset import DatasetRecordConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .recipe import MessageTurn, TrainingRecipe, load_recipe
from .types import (
FeatureType,
NormalizationMode,
@@ -32,12 +31,6 @@ from .types import (
PolicyFeature,
RTCAttentionSchedule,
)
from .video import (
VALID_VIDEO_CODECS,
VIDEO_ENCODER_INFO_KEYS,
VideoEncoderConfig,
camera_encoder_defaults,
)
__all__ = [
# Types
@@ -50,16 +43,7 @@ __all__ = [
"DatasetRecordConfig",
"DatasetConfig",
"EvalConfig",
"MessageTurn",
"PeftConfig",
"PreTrainedConfig",
"TrainingRecipe",
"WandBConfig",
"load_recipe",
"VideoEncoderConfig",
# Defaults
"camera_encoder_defaults",
# Constants
"VALID_VIDEO_CODECS",
"VIDEO_ENCODER_INFO_KEYS",
]

View File

@@ -14,12 +14,10 @@
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
from dataclasses import dataclass, field
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from .video import VideoEncoderConfig, camera_encoder_defaults
@dataclass
class DatasetRecordConfig:
@@ -57,9 +55,10 @@ class DatasetRecordConfig:
# Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
# Use 'auto' to auto-detect the best available hardware encoder.
vcodec: str = "libsvtav1"
# Enable streaming video encoding: encode frames in real-time during capture instead
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
streaming_encoding: bool = False

View File

@@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from lerobot.transforms import ImageTransformsConfig
from lerobot.utils.import_utils import get_safe_default_video_backend
from lerobot.utils.import_utils import get_safe_default_codec
@dataclass
@@ -34,7 +34,7 @@ class DatasetConfig:
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_video_backend)
video_backend: str = field(default_factory=get_safe_default_codec)
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False

View File

@@ -18,8 +18,8 @@ from logging import getLogger
from pathlib import Path
from lerobot import envs, policies # noqa: F401
from lerobot.configs import parser
from . import parser
from .default import EvalConfig
from .policies import PreTrainedConfig

View File

@@ -255,7 +255,8 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) ->
remaining = config_data[field]
if remaining:
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
del config_data[field]
else:
del config_data[field]
modified = True
if not modified:
@@ -310,13 +311,7 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
cli_args = filter_arg("config_path", cli_args)
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
else:
if config_path_cli:
cli_args = filter_arg("config_path", cli_args)
cfg = draccus.parse(
config_class=argtype,
config_path=config_path_cli or config_path,
args=cli_args,
)
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
response = fn(cfg, *args, **kwargs)
return response

View File

@@ -1,206 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, get_args
MessageRole = Literal["user", "assistant", "system", "tool"]
MessageStream = Literal["high_level", "low_level"]
DEFAULT_BINDINGS = {
"subtask": "active_at(t, style=subtask)",
"memory": "active_at(t, style=memory)",
"plan": "active_at(t, style=plan)",
"speech": "emitted_at(t, role=assistant, tool_name=say)",
"interjection": "emitted_at(t, style=interjection)",
"vqa": "emitted_at(t, style=vqa, role=assistant)",
"vqa_query": "emitted_at(t, style=vqa, role=user)",
}
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
"""``${name}`` placeholder pattern used by both recipe binding-reference
discovery (here) and rendered-message substitution (in ``language_render``)."""
_VALID_ROLES = frozenset(get_args(MessageRole))
_VALID_STREAMS = frozenset(get_args(MessageStream))
@dataclass
class MessageTurn:
"""A single chat-style turn in a recipe template.
``content`` may be a plain string, a list of HF-style multimodal blocks, or
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
training target, and ``if_present`` skips the turn when the named binding
resolves to ``None``.
"""
role: MessageRole
content: str | list[dict[str, Any]] | None = None
stream: MessageStream | None = None
target: bool = False
if_present: str | None = None
tool_calls_from: str | None = None
def __post_init__(self) -> None:
"""Validate role, stream, and content after dataclass construction."""
if self.role not in _VALID_ROLES:
raise ValueError(f"Unsupported message role: {self.role!r}")
# ``stream`` is typed Optional only so the dataclass can keep its
# field ordering, but recipes must always tag every turn with a
# stream — the renderer's ``_validate_rendered`` would reject
# ``None`` later on. Fail at construction so the bad recipe is
# caught at YAML load time rather than at the first sample.
if self.stream is None:
raise ValueError(
f"MessageTurn(role={self.role!r}) is missing a stream — "
f"every turn must declare one of {sorted(_VALID_STREAMS)}."
)
if self.stream not in _VALID_STREAMS:
raise ValueError(f"Unsupported message stream: {self.stream!r}")
if self.content is None and self.tool_calls_from is None:
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
if self.content is not None and not isinstance(self.content, (str, list)):
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
if isinstance(self.content, list):
for block in self.content:
if not isinstance(block, dict) or "type" not in block:
raise ValueError(
"Multimodal content blocks must be HF-style dictionaries with a type key."
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
"""Construct a :class:`MessageTurn` from a plain dictionary."""
return cls(**data)
@dataclass
class TrainingRecipe:
"""A recipe describing how to render training samples from language rows.
A recipe is either a *message recipe* (``messages`` plus optional
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
sub-recipes). ``weight`` is only meaningful inside a blend.
"""
messages: list[MessageTurn] | None = None
bindings: dict[str, str] | None = None
blend: dict[str, TrainingRecipe] | None = None
weight: float | None = None
def __post_init__(self) -> None:
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
if self.messages is not None and self.blend is not None:
raise ValueError("TrainingRecipe must set only one of messages or blend.")
if self.messages is None and self.blend is None:
raise ValueError("TrainingRecipe must set one of messages or blend.")
if self.messages is not None:
self._validate_message_recipe()
if self.blend is not None:
self._validate_blend_recipe()
@classmethod
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
data = dict(data)
if data.get("messages") is not None:
data["messages"] = [
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
for turn in data["messages"]
]
if data.get("blend") is not None:
data["blend"] = {
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
for name, recipe in data["blend"].items()
}
return cls(**data)
@classmethod
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
import yaml # type: ignore[import-untyped]
with open(path) as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
return cls.from_dict(data)
def _validate_message_recipe(self) -> None:
"""Ensure every templated binding is known and at least one turn is a target."""
assert self.messages is not None
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
for turn in self.messages:
missing = self._referenced_bindings(turn) - known_bindings
if missing:
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
if not any(turn.target for turn in self.messages):
raise ValueError("Message recipes must contain at least one target turn.")
def _validate_blend_recipe(self) -> None:
"""Ensure each blend component is a non-empty, weighted message recipe."""
assert self.blend is not None
if not self.blend:
raise ValueError("Blend recipes must contain at least one component.")
for name, recipe in self.blend.items():
if recipe.blend is not None:
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
if recipe.messages is None:
raise ValueError(f"Blend component {name!r} must define messages.")
if recipe.weight is None:
raise ValueError(f"Blend component {name!r} must define weight.")
if recipe.weight <= 0:
raise ValueError(f"Blend component {name!r} must have a positive weight.")
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
"""Return the binding names that ``turn`` references via placeholders or attributes."""
names: set[str] = set()
if turn.if_present is not None:
names.add(turn.if_present)
if turn.tool_calls_from is not None:
names.add(turn.tool_calls_from)
names.update(_placeholders_in_content(turn.content))
return names
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
if content is None:
return set()
if isinstance(content, str):
return set(PLACEHOLDER_RE.findall(content))
names: set[str] = set()
for block in content:
for value in block.values():
if isinstance(value, str):
names.update(PLACEHOLDER_RE.findall(value))
return names
def load_recipe(path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
return TrainingRecipe.from_yaml(path)

View File

@@ -27,13 +27,12 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import PolicyFeature
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
from lerobot.utils.hub import HubMixin
from .types import PolicyFeature
T = TypeVar("T", bound="RewardModelConfig")
logger = logging.getLogger(__name__)
@@ -91,7 +90,14 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
return None
def get_optimizer_preset(self) -> OptimizerConfig | None:
"""Default optimizer for this reward model, or ``None`` for zero-shot models."""
"""Default optimizer for this reward model, or ``None`` for zero-shot models.
Trainable reward models (e.g. SARM, Classifier) must override this with a
concrete optimizer config. Zero-shot reward models (e.g. Robometer) leave
the default ``None`` — they error out earlier via the
:attr:`~lerobot.rewards.pretrained.PreTrainedRewardModel.is_trainable`
check in ``lerobot-train``.
"""
return None
def get_scheduler_preset(self) -> LRSchedulerConfig | None:

View File

@@ -25,11 +25,11 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HfHubHTTPError
from lerobot import envs
from lerobot.configs import parser
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
from lerobot.utils.hub import HubMixin
from lerobot.utils.sample_weighting import SampleWeightingConfig
from . import parser
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .rewards import RewardModelConfig

View File

@@ -1,235 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: We subclass str so that serialization is straightforward
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
"""Video encoder configurations."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
from lerobot.utils.import_utils import require_package
logger = logging.getLogger(__name__)
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and the chosen video backend.
# Determines the order of preference for auto-selection when vcodec="auto" is used.
HW_VIDEO_CODECS = [
"h264_videotoolbox", # macOS
"hevc_videotoolbox", # macOS
"h264_nvenc", # NVIDIA GPU
"hevc_nvenc", # NVIDIA GPU
"h264_vaapi", # Linux Intel/AMD
"h264_qsv", # Intel Quick Sync
]
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
# Aliases for legacy video codec names.
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
LIBSVTAV1_DEFAULT_PRESET: int = 12
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
# ``vcodec``` and ``pix_fmt`` are derived from the video stream directly.
VIDEO_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset(
{"g", "crf", "preset", "fast_decode", "extra_options", "video_backend"}
)
VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
)
@dataclass
class VideoEncoderConfig:
"""Video encoder configuration.
Attributes:
vcodec: Video encoder name. ``"auto"`` is resolved during
construction (HW encoder if available, else ``libsvtav1``).
pix_fmt: Pixel format (e.g. ``"yuv420p"``).
g: GOP size (keyframe interval).
crf: Quality level — mapped to the native quality parameter of the
codec (``crf`` for software, ``qp`` for NVENC/VAAPI,
``q:v`` for VideoToolbox, ``global_quality`` for QSV).
preset: Speed/quality preset. Accepted type is per-codec.
fast_decode: Fast-decode tuning. For ``libsvtav1`` this is a level (0-2)
embedded in ``svtav1-params``. For ``h264`` and ``hevc`` non-zero values
set ``tune=fastdecode``. Ignored for other codecs.
video_backend: Python to be used for encoding. Only ``"pyav"``
is currently supported.
extra_options: Free-form dictionary of additional video encoder options
(e.g. ``{"tune": "film", "profile:v": "high", "bf": 2}``).
"""
vcodec: str = "libsvtav1" # TODO(CarolinePascal): rename to codec ?
pix_fmt: str = "yuv420p"
g: int | None = 2
crf: int | float | None = 30
preset: int | str | None = None
fast_decode: int = 0
# TODO(CarolinePascal): add torchcodec support + find a way to unify the
# two backends (encoding and decoding).
video_backend: str = "pyav"
extra_options: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
self.resolve_vcodec()
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
if self.preset is None and self.vcodec == "libsvtav1":
self.preset = LIBSVTAV1_DEFAULT_PRESET
self.validate()
@classmethod
def from_video_info(cls, video_info: dict | None) -> VideoEncoderConfig:
"""Reconstruct a :class:`VideoEncoderConfig` from a video feature's ``info`` block.
Missing or ``None`` values fall back to the class defaults.
"""
video_info = video_info or {}
kwargs: dict[str, Any] = {}
for src_key, dst_field in (("video.codec", "vcodec"), ("video.pix_fmt", "pix_fmt")):
value = video_info.get(src_key)
if value is not None:
kwargs[dst_field] = value
for field_name in VIDEO_ENCODER_INFO_FIELD_NAMES:
value = video_info.get(f"video.{field_name}")
if value is None:
continue
# Persisted as ``{}`` after merges with disagreeing sources — treat as default.
if field_name == "extra_options" and not value:
continue
kwargs[field_name] = value
return cls(**kwargs)
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
"""Return the subset of available encoders based on the specified video backend.
Args:
encoders: List of encoder names to detect. If a string, it is converted to a list.
Returns:
List of available encoder names. If the video backend is not "pyav", returns an empty list.
"""
if self.video_backend == "pyav":
require_package("av", extra="dataset")
from lerobot.datasets import detect_available_encoders_pyav
return detect_available_encoders_pyav(encoders)
return []
def validate(self) -> None:
"""Validate the video encoder configuration."""
if self.video_backend == "pyav":
require_package("av", extra="dataset")
from lerobot.datasets import check_video_encoder_parameters_pyav
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options())
def resolve_vcodec(self) -> None:
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
For ``"auto"``, the first hardware encoder in the preference list that is available is chosen; if none are available, ``libsvtav1`` is used. If the
resolved codec (explicit or after auto-selection) is not available, raises ``ValueError``.
Stream-derived canonical codec names listed in :data:`VIDEO_CODECS_ALIASES` are
rewritten to their corresponding encoder name (e.g. ``"av1"`` → ``"libsvtav1"``).
"""
self.vcodec = VIDEO_CODECS_ALIASES.get(self.vcodec, self.vcodec)
if self.vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{self.vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if self.vcodec == "auto":
available = self.detect_available_encoders(HW_VIDEO_CODECS)
for encoder in HW_VIDEO_CODECS:
if encoder in available:
logger.info(f"Auto-selected video codec: {encoder}")
self.vcodec = encoder
return
logger.warning("No hardware encoder available, falling back to software encoder 'libsvtav1'")
self.vcodec = "libsvtav1"
if self.detect_available_encoders(self.vcodec):
logger.info(f"Using video codec: {self.vcodec}")
return
raise ValueError(f"Unsupported video codec: {self.vcodec} with video backend {self.video_backend}")
def get_codec_options(
self, encoder_threads: int | None = None, as_strings: bool = False
) -> dict[str, Any]:
"""Translate the tuning fields to codec-specific options.
``VideoEncoderConfig.extra_options`` are merged last but never override a structured field.
Args:
encoder_threads: Number of encoder threads set globally for all VideoEncoderConfigs.
For libsvtav1, this is mapped to ``lp`` via ``svtav1-params``.
For h264/hevc, this is mapped to ``threads``.
Hardware encoders ignore this parameter.
as_strings: If ``True``, casts values to strings.
"""
opts: dict[str, Any] = {}
def set_if(key: str, value: Any) -> None:
if value is not None:
opts[key] = value if not as_strings else str(value)
# GOP size is not a codec-specific option, so it is always set.
set_if("g", self.g)
if self.vcodec == "libsvtav1":
set_if("crf", self.crf)
set_if("preset", self.preset)
svtav1_parts: list[str] = []
if self.fast_decode is not None:
svtav1_parts.append(f"fast-decode={max(0, min(2, self.fast_decode))}")
if encoder_threads is not None:
svtav1_parts.append(f"lp={encoder_threads}")
if svtav1_parts:
opts["svtav1-params"] = ":".join(svtav1_parts)
elif self.vcodec in ("h264", "hevc"):
set_if("crf", self.crf)
set_if("preset", self.preset)
if self.fast_decode:
opts["tune"] = "fastdecode"
set_if("threads", encoder_threads)
elif self.vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
if self.crf is not None:
opts["q:v"] = max(1, min(100, 100 - self.crf * 2))
elif self.vcodec in ("h264_nvenc", "hevc_nvenc"):
opts["rc"] = 0
set_if("qp", self.crf)
set_if("preset", self.preset)
elif self.vcodec == "h264_vaapi":
set_if("qp", self.crf)
elif self.vcodec == "h264_qsv":
set_if("global_quality", self.crf)
set_if("preset", self.preset)
else:
set_if("crf", self.crf)
set_if("preset", self.preset)
# Extra options are merged last but never override structured fields (values are kept as given).
for k, v in self.extra_options.items():
if k not in opts:
set_if(k, v)
return opts
def camera_encoder_defaults() -> VideoEncoderConfig:
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
return VideoEncoderConfig()

View File

@@ -31,25 +31,15 @@ from .dataset_tools import (
modify_features,
modify_tasks,
recompute_stats,
reencode_dataset,
remove_feature,
split_dataset,
)
from .factory import make_dataset, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats
from .language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
column_for_style,
)
from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
from .sampler import EpisodeAwareSampler
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
@@ -63,19 +53,12 @@ __all__ = [
"CODEBASE_VERSION",
"DEFAULT_EPISODES_PATH",
"DEFAULT_QUANTILES",
"EVENT_ONLY_STYLES",
"EpisodeAwareSampler",
"LANGUAGE_EVENTS",
"LANGUAGE_PERSISTENT",
"LeRobotDataset",
"LeRobotDatasetMetadata",
"MultiLeRobotDataset",
"PERSISTENT_STYLES",
"STYLE_REGISTRY",
"StreamingLeRobotDataset",
"VideoEncodingManager",
"check_video_encoder_parameters_pyav",
"detect_available_encoders_pyav",
"add_features",
"aggregate_datasets",
"aggregate_pipeline_dataset_features",
@@ -83,7 +66,6 @@ __all__ = [
"convert_image_to_video_dataset",
"create_initial_features",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
"get_feature_stats",
"load_episodes",
@@ -92,7 +74,6 @@ __all__ = [
"modify_features",
"modify_tasks",
"recompute_stats",
"reencode_dataset",
"remove_feature",
"resolve_delta_timestamps",
"safe_stop_image_writer",

View File

@@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import shutil
from pathlib import Path
@@ -24,11 +23,9 @@ import datasets
import pandas as pd
import tqdm
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
from .compute_stats import aggregate_stats
from .dataset_metadata import LeRobotDatasetMetadata
from .feature_utils import features_equal_for_merge, get_hf_features_from_features
from .feature_utils import get_hf_features_from_features
from .io_utils import (
get_file_size_in_mb,
get_parquet_file_size_in_mb,
@@ -49,54 +46,11 @@ from .utils import (
from .video_utils import concatenate_video_files, get_video_duration_in_s
def merge_video_feature_info_for_aggregate(all_metadata: list[LeRobotDatasetMetadata]) -> dict[str, dict]:
"""Create a merged video feature info dictionary for aggregation. The video encoder info is merged field-by-field: each key is kept only when every source agrees; otherwise that key is set to ``null`` (or ``{}`` for ``video.extra_options``) and a warning is logged.
Args:
all_metadata: List of LeRobotDatasetMetadata objects to merge.
Returns:
dict: A dictionary of merged video feature info.
"""
merged_info = copy.deepcopy(all_metadata[0].features)
video_keys = [k for k in merged_info if merged_info[k].get("dtype") == "video"]
for vk in video_keys:
video_infos = [m.features.get(vk, {}).get("info") or {} for m in all_metadata]
base_video_info = video_infos[0]
merged_encoder_info: dict = {}
fallback_keys: list[str] = []
for info_key in VIDEO_ENCODER_INFO_KEYS:
values = [info.get(info_key, None) for info in video_infos]
first_value = values[0]
all_match = all(v == first_value for v in values[1:])
if all_match:
merged_encoder_info[info_key] = first_value
else:
fallback_keys.append(info_key)
merged_encoder_info[info_key] = {} if info_key == "video.extra_options" else None
if fallback_keys:
logging.warning(
f"Merging heterogeneous or incomplete video encoder metadata for feature {vk}. "
f"Setting these keys to null: {fallback_keys}.",
)
merged_info[vk]["info"] = {**base_video_info, **merged_encoder_info}
# TODO(CarolinePascal): make this variable once we have support for other video backends.
merged_info[vk]["info"]["video.video_backend"] = "pyav"
return merged_info
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
"""Validates that all dataset metadata have consistent properties.
Ensures all datasets have the same fps, robot_type, and features to guarantee
compatibility when aggregating them into a single dataset.
Video encoder info is not considered for validation but is merged during aggregation in ``merge_video_feature_info_for_aggregate``.
Args:
all_metadata: List of LeRobotDatasetMetadata objects to validate.
@@ -120,7 +74,7 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
raise ValueError(
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
)
if not features_equal_for_merge(features, meta.features):
if features != meta.features:
raise ValueError(
f"Same features is expected, but got features={meta.features} instead of {features}."
)
@@ -320,8 +274,7 @@ def aggregate_datasets(
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
]
)
fps, robot_type, _ = validate_all_metadata(all_metadata)
features = merge_video_feature_info_for_aggregate(all_metadata)
fps, robot_type, features = validate_all_metadata(all_metadata)
video_keys = [key for key in features if features[key]["dtype"] == "video"]
dst_meta = LeRobotDatasetMetadata.create(
@@ -379,6 +332,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
videos_idx: Dictionary tracking video chunk and file indices.
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
Returns:
dict: Updated videos_idx with current chunk and file indices.
"""
@@ -460,11 +414,9 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
current_dst_duration = dst_file_durations.get(dst_key, 0)
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
# TODO(CarolinePascal): Move the check before the loop to avoid failing in the middle + add possibility to re-encode the video if the check fails
concatenate_video_files(
[dst_path, src_path],
dst_path,
compatibility_check=True,
)
# Update duration of this destination file
dst_file_durations[dst_key] = current_dst_duration + src_duration

View File

@@ -512,7 +512,7 @@ def compute_episode_stats(
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] in {"string", "language"}:
if features[key]["dtype"] == "string":
continue
if features[key]["dtype"] in ["image", "video"]:

View File

@@ -24,7 +24,6 @@ import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import snapshot_download
from lerobot.configs import VideoEncoderConfig
from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
from lerobot.utils.feature_utils import _validate_feature_names
from lerobot.utils.utils import flatten_dict
@@ -36,12 +35,12 @@ from .io_utils import (
load_episodes,
load_info,
load_stats,
load_subtasks,
load_tasks,
write_info,
write_stats,
write_tasks,
)
from .language import DEFAULT_TOOLS, LANGUAGE_COLUMNS
from .utils import (
DEFAULT_EPISODES_PATH,
check_version_compatibility,
@@ -177,6 +176,7 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -342,49 +342,6 @@ class LeRobotDatasetMetadata:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def has_language_columns(self) -> bool:
"""Return ``True`` if the dataset declares any language column.
Used to gate language-aware code paths (collate, render step) so
unannotated datasets keep PyTorch's default collate behavior.
"""
return any(col in self.features for col in LANGUAGE_COLUMNS)
@property
def tools(self) -> list[dict]:
"""OpenAI-style tool schemas declared by this dataset.
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
can mutate the result safely. Falls back to
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
``say`` schema) when the dataset doesn't declare any — that way
unannotated datasets and chat-template consumers
(``apply_chat_template(messages, tools=meta.tools)``) keep
working out of the box.
Implementations live under :mod:`lerobot.tools` (one file per
tool); see ``docs/source/tools.mdx`` for the authoring guide.
"""
declared = self.info.tools
if declared:
return [dict(t) for t in declared]
return [dict(t) for t in DEFAULT_TOOLS]
@tools.setter
def tools(self, value: list[dict] | None) -> None:
"""Persist a tool catalog to ``meta/info.json`` and reload metadata.
Writes ``value`` into the on-disk ``info.json`` (or clears the
``tools`` key when ``value`` is ``None`` or empty), then reloads
``self.info`` so the in-memory metadata matches what's on disk.
Saves callers from hand-editing ``info.json`` and re-instantiating
the metadata object.
"""
self.info.tools = [dict(t) for t in value] if value else None
write_info(self.info, self.root)
self.info = load_info(self.root)
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
@@ -577,23 +534,10 @@ class LeRobotDatasetMetadata:
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
write_stats(self.stats, self.root)
def update_video_info(
self,
video_key: str | None = None,
camera_encoder: VideoEncoderConfig | None = None,
) -> None:
"""Populate per-feature video info in ``info.json``.
def update_video_info(self, video_key: str | None = None) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
Args:
video_key: If provided, only update this video key. Otherwise update
all video keys in the dataset.
camera_encoder: Encoder configuration used to produce the
videos. When provided, its fields are recorded as
``video.<field>`` entries alongside the stream-derived
``video.*`` entries (see :func:`get_video_info`).
"""
if video_key is not None and video_key not in self.video_keys:
raise ValueError(f"Video key {video_key} not found in dataset")
@@ -602,7 +546,7 @@ class LeRobotDatasetMetadata:
for key in video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info.features[key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
self.info.features[key]["info"] = get_video_info(video_path)
def update_chunk_settings(
self,
@@ -713,6 +657,7 @@ class LeRobotDatasetMetadata:
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(

View File

@@ -295,4 +295,9 @@ class DatasetReader:
task_idx = item["task_index"].item()
item["task"] = self._meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
return item

View File

@@ -26,7 +26,7 @@ This module provides utilities for:
import logging
import shutil
from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import datasets
@@ -36,7 +36,6 @@ import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
from lerobot.utils.utils import flatten_dict
@@ -61,14 +60,9 @@ from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
VIDEO_DIR,
update_chunk_file_indices,
)
from .video_utils import (
encode_video_frames,
get_video_info,
reencode_video,
)
from .video_utils import encode_video_frames, get_video_info
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
@@ -101,11 +95,6 @@ def delete_episodes(
) -> LeRobotDataset:
"""Delete episodes from a LeRobotDataset and create a new dataset.
Video segments that need re-encoding (because the source file mixes kept and
deleted episodes) are re-encoded with the source dataset's existing encoder
settings — read back from ``meta/info.json`` — so the output dataset stays
consistent with its own metadata.
Args:
dataset: The source LeRobotDataset.
episode_indices: List of episode indices to delete.
@@ -168,11 +157,6 @@ def split_dataset(
) -> dict[str, LeRobotDataset]:
"""Split a LeRobotDataset into multiple smaller datasets.
Video segments that need re-encoding (because the source file mixes episodes
that fall into different splits) are re-encoded with the source dataset's
existing encoder settings — read back from ``meta/info.json`` — so each
output split stays consistent with its own metadata.
Args:
dataset: The source LeRobotDataset to split.
splits: Either a dict mapping split names to episode indices, or a dict mapping
@@ -594,7 +578,8 @@ def _keep_episodes_from_video_with_av(
output_path: Path,
episodes_to_keep: list[tuple[int, int]],
fps: float,
camera_encoder: VideoEncoderConfig,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
) -> None:
"""Keep only specified episodes from a video file using PyAV.
@@ -608,7 +593,8 @@ def _keep_episodes_from_video_with_av(
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
is inclusive and end_frame is exclusive.
fps: Frame rate of the video.
camera_encoder: Video encoder settings used to re-encode the kept frames.
vcodec: Video codec to use for encoding.
pix_fmt: Pixel format for output video.
"""
from fractions import Fraction
@@ -633,13 +619,12 @@ def _keep_episodes_from_video_with_av(
# Convert fps to Fraction for PyAV compatibility.
fps_fraction = Fraction(fps).limit_denominator(1000)
codec_options = camera_encoder.get_codec_options(as_strings=True)
v_out = out.add_stream(camera_encoder.vcodec, rate=fps_fraction, options=codec_options)
v_out = out.add_stream(vcodec, rate=fps_fraction)
# PyAV type stubs don't distinguish video streams from audio/subtitle streams.
v_out.width = v_in.codec_context.width
v_out.height = v_in.codec_context.height
v_out.pix_fmt = camera_encoder.pix_fmt
v_out.pix_fmt = pix_fmt
# Set time_base to match the frame rate for proper timestamp handling.
v_out.time_base = Fraction(1, int(fps))
@@ -702,14 +687,14 @@ def _copy_and_reindex_videos(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_mapping: dict[int, int],
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
) -> dict[int, dict]:
"""Copy and filter video files, only re-encoding files with deleted episodes.
For video files that only contain kept episodes, we copy them directly.
For files with mixed kept/deleted episodes, we use PyAV filters to efficiently
re-encode only the desired segments. The encoder used for re-encoding is
derived per video key from the source dataset's ``meta/info.json`` so the
destination metadata keeps describing the videos accurately.
re-encode only the desired segments.
Args:
src_dataset: Source dataset to copy from
@@ -726,9 +711,6 @@ def _copy_and_reindex_videos(
for video_key in src_dataset.meta.video_keys:
logging.info(f"Processing videos for {video_key}")
camera_encoder = VideoEncoderConfig.from_video_info(
src_dataset.meta.info.features.get(video_key, {}).get("info")
)
if dst_meta.video_path is None:
raise ValueError("Destination metadata has no video_path defined")
@@ -810,7 +792,8 @@ def _copy_and_reindex_videos(
dst_video_path,
episodes_to_keep_ranges,
src_dataset.meta.fps,
camera_encoder,
vcodec,
pix_fmt,
)
cumulative_ts = 0.0
@@ -1281,7 +1264,11 @@ def _estimate_frame_size_via_calibration(
episode_indices: list[int],
temp_dir: Path,
fps: int,
camera_encoder: VideoEncoderConfig,
vcodec: str,
pix_fmt: str,
g: int,
crf: int,
fast_decode: int,
num_calibration_frames: int = 30,
) -> float:
"""Estimate MB per frame by encoding a small calibration sample.
@@ -1295,7 +1282,11 @@ def _estimate_frame_size_via_calibration(
episode_indices: List of episode indices being processed.
temp_dir: Temporary directory for calibration files.
fps: Frames per second for video encoding.
camera_encoder: Video encoder settings used for calibration encoding.
vcodec: Video codec (libsvtav1, h264, hevc).
pix_fmt: Pixel format (yuv420p, etc.).
g: GOP size (group of pictures).
crf: Constant Rate Factor (quality).
fast_decode: Fast decode tuning parameter.
num_calibration_frames: Number of frames to use for calibration (default: 30).
Returns:
@@ -1331,7 +1322,11 @@ def _estimate_frame_size_via_calibration(
imgs_dir=calibration_dir,
video_path=calibration_video_path,
fps=fps,
camera_encoder=camera_encoder,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
overwrite=True,
)
@@ -1649,7 +1644,11 @@ def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path | None = None,
repo_id: str | None = None,
camera_encoder: VideoEncoderConfig | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int = 2,
crf: int = 30,
fast_decode: int = 0,
episode_indices: list[int] | None = None,
num_workers: int = 4,
max_episodes_per_batch: int | None = None,
@@ -1664,8 +1663,11 @@ def convert_image_to_video_dataset(
dataset: The source LeRobot dataset with images
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
camera_encoder: Video encoder settings
(``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`).
vcodec: Video codec (default: libsvtav1)
pix_fmt: Pixel format (default: yuv420p)
g: Group of pictures size (default: 2)
crf: Constant rate factor (default: 30)
fast_decode: Fast decode tuning (default: 0)
episode_indices: List of episode indices to convert (None = all episodes)
num_workers: Number of threads for parallel processing (default: 4)
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
@@ -1674,9 +1676,6 @@ def convert_image_to_video_dataset(
Returns:
New LeRobotDataset with images encoded as videos
"""
if camera_encoder is None:
camera_encoder = camera_encoder_defaults()
# Check that it's an image dataset
if len(dataset.meta.video_keys) > 0:
raise ValueError(
@@ -1700,10 +1699,7 @@ def convert_image_to_video_dataset(
logging.info(
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
)
logging.info(
f"Video codec: {camera_encoder.vcodec}, pixel format: {camera_encoder.pix_fmt}, "
f"GOP: {camera_encoder.g}, CRF: {camera_encoder.crf}"
)
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
# Create new features dict, converting image features to video features
new_features = {}
@@ -1773,7 +1769,11 @@ def convert_image_to_video_dataset(
episode_indices=episode_indices,
temp_dir=temp_dir,
fps=fps,
camera_encoder=camera_encoder,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
)
logging.info(f"Processing camera: {img_key}")
@@ -1815,7 +1815,11 @@ def convert_image_to_video_dataset(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
camera_encoder=camera_encoder,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
overwrite=True,
)
@@ -1861,9 +1865,7 @@ def convert_image_to_video_dataset(
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=0, file_index=0
)
new_meta.info.features[img_key]["info"] = get_video_info(
video_path, camera_encoder=camera_encoder
)
new_meta.info.features[img_key]["info"] = get_video_info(video_path)
write_info(new_meta.info, new_meta.root)
@@ -1886,83 +1888,3 @@ def convert_image_to_video_dataset(
# Return new dataset
return LeRobotDataset(repo_id=repo_id, root=output_dir)
def _reencode_video_worker(args: tuple) -> Path:
"""Picklable worker for :func:`reencode_dataset`'s process pool."""
video_path, camera_encoder, encoder_threads = args
reencode_video(
input_video_path=video_path,
output_video_path=video_path,
camera_encoder=camera_encoder,
encoder_threads=encoder_threads,
overwrite=True,
)
return video_path
def reencode_dataset(
dataset: LeRobotDataset,
camera_encoder: VideoEncoderConfig,
encoder_threads: int | None = None,
num_workers: int | None = None,
) -> LeRobotDataset:
"""Re-encode every video in a dataset with a new set of encoding parameters.
Videos are re-encoded in-place and the video information in ``info.json`` is refreshed.
Args:
dataset: An existing :class:`LeRobotDataset` whose videos will be
re-encoded.
camera_encoder: Target encoder configuration applied to every video
file.
encoder_threads: Per-encoder thread count forwarded to
:func:`reencode_video`. ``None`` lets the codec decide.
num_workers: Number of parallel processes. ``None`` or ``0`` means
sequential (no multiprocessing); ``1+`` spawns a
:class:`~concurrent.futures.ProcessPoolExecutor`.
Returns:
The same :class:`LeRobotDataset` instance with its metadata updated
on disk.
"""
meta = dataset.meta
video_paths_list = []
# Only re-encode if the videos are not already encoded with the given video encoding parameters
for video_key in meta.video_keys:
current_info = meta.info.features[video_key].get("info", {})
current_encoder = VideoEncoderConfig.from_video_info(current_info)
if current_encoder != camera_encoder:
video_paths_list.extend((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
else:
logging.info(f"{video_key} videos are already encoded with {camera_encoder}. Nothing to do.")
if len(video_paths_list) == 0:
logging.warning("Dataset has no videos to re-encode.")
return dataset
logging.info(f"Re-encoding {len(video_paths_list)} video file(s) with {camera_encoder}")
worker_args = [(vp, camera_encoder, encoder_threads) for vp in video_paths_list]
if num_workers and num_workers > 1:
with ProcessPoolExecutor(max_workers=num_workers) as pool:
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
for future in tqdm(
as_completed(futures),
total=len(futures),
desc="Re-encoding videos",
):
future.result()
else:
for args in tqdm(worker_args, desc="Re-encoding videos"):
_reencode_video_worker(args)
# Refresh video info in metadata for every video key.
for vid_key in meta.video_keys:
video_path = meta.root / meta.get_video_file_path(0, vid_key)
meta.info.features[vid_key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
write_info(meta.info, meta.root)
logging.info("Dataset metadata updated.")
return dataset

View File

@@ -31,8 +31,6 @@ import PIL.Image
import pyarrow.parquet as pq
import torch
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
from .compute_stats import compute_episode_stats
from .dataset_metadata import LeRobotDatasetMetadata
from .feature_utils import (
@@ -67,19 +65,14 @@ def _encode_video_worker(
episode_index: int,
root: Path,
fps: int,
camera_encoder: VideoEncoderConfig | None = None,
vcodec: str = "libsvtav1",
encoder_threads: int | None = None,
) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent
encode_video_frames(
img_dir,
temp_path,
fps,
camera_encoder=camera_encoder,
encoder_threads=encoder_threads,
overwrite=True,
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
)
shutil.rmtree(img_dir)
return temp_path
@@ -96,22 +89,20 @@ class DatasetWriter:
self,
meta: LeRobotDatasetMetadata,
root: Path,
camera_encoder: VideoEncoderConfig | None,
vcodec: str,
encoder_threads: int | None,
batch_encoding_size: int,
streaming_encoder: StreamingVideoEncoder | None = None,
initial_frames: int = 0,
):
"""Initialize the writer with metadata, codec, and encoder config.
"""Initialize the writer with metadata, codec, and encoding config.
Args:
meta: Dataset metadata instance (used for feature schema, chunk
settings, and episode persistence).
root: Local dataset root directory.
camera_encoder: Video encoder settings applied to all cameras.
``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``).
encoder_threads: Threads per encoder instance. ``None`` for auto.
batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos.
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
@@ -120,7 +111,7 @@ class DatasetWriter:
"""
self._meta = meta
self._root = root
self._camera_encoder = camera_encoder or camera_encoder_defaults()
self._vcodec = vcodec
self._encoder_threads = encoder_threads
self._batch_encoding_size = batch_encoding_size
self._streaming_encoder = streaming_encoder
@@ -250,14 +241,7 @@ class DatasetWriter:
for key, ft in self._meta.features.items():
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
stacked_values = np.stack(episode_buffer[key])
# `shape=(1,)` numeric features are serialized as `datasets.Value`, which expects scalars.
# Normalizing to `(N,)` keeps save semantics stable across dependency versions.
if tuple(ft["shape"]) == (1,) and ft["dtype"] != "string":
stacked_values = stacked_values.reshape(episode_length)
episode_buffer[key] = stacked_values
episode_buffer[key] = np.stack(episode_buffer[key])
# Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer()
@@ -300,7 +284,7 @@ class DatasetWriter:
episode_index,
self._root,
self._meta.fps,
self._camera_encoder,
self._vcodec,
self._encoder_threads,
): video_key
for video_key in self._meta.video_keys
@@ -511,7 +495,7 @@ class DatasetWriter:
# Update video info (only needed when first episode is encoded)
if episode_index == 0:
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
self._meta.update_video_info(video_key)
write_info(self._meta.info, self._meta.root)
metadata = {
@@ -580,12 +564,7 @@ class DatasetWriter:
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
return _encode_video_worker(
video_key,
episode_index,
self._root,
self._meta.fps,
self._camera_encoder,
self._encoder_threads,
video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads
)
def close_writer(self) -> None:

View File

@@ -13,23 +13,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pprint import pformat
import datasets
import numpy as np
from PIL import Image as PILImage
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
from lerobot.utils.constants import DEFAULT_FEATURES
from lerobot.utils.utils import is_valid_numpy_dtype_string
from .language import (
LANGUAGE_PERSISTENT,
is_language_column,
language_events_column_feature,
language_persistent_column_feature,
)
from .utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -54,13 +46,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
"""
hf_features = {}
for key, ft in features.items():
if is_language_column(key):
hf_features[key] = (
language_persistent_column_feature()
if key == LANGUAGE_PERSISTENT
else language_events_column_feature()
)
elif ft["dtype"] == "video":
if ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
@@ -122,41 +108,6 @@ def create_empty_dataset_info(
)
def features_equal_for_merge(features_a: dict[str, dict], features_b: dict[str, dict]) -> bool:
"""Return whether two LeRobotDatasetMetadata ``features`` dicts are compatible for aggregation.
For video features, keys under ``info`` related to video encoding parameters are ignored during
comparison as they do not prevent aggregation.
"""
def _without_encoder_info_keys(feature: dict) -> dict:
filtered = dict(feature)
filtered_info = filtered.get("info")
if isinstance(filtered_info, dict):
filtered["info"] = {
info_key: info_value
for info_key, info_value in filtered_info.items()
if info_key not in VIDEO_ENCODER_INFO_KEYS
}
return filtered
if set(features_a) != set(features_b):
return False
for key in features_a:
fa_key = features_a[key]
fb_key = features_b[key]
if fa_key.get("dtype") != fb_key.get("dtype"):
return False
if fa_key.get("dtype") != "video":
if fa_key != fb_key:
return False
continue
if _without_encoder_info_keys(fa_key) != _without_encoder_info_keys(fb_key):
return False
return True
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
@@ -291,8 +242,6 @@ def validate_feature_dtype_and_shape(
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
elif expected_dtype == "language":
return validate_feature_language(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
@@ -372,30 +321,6 @@ def validate_feature_string(name: str, value: str) -> str:
return ""
def validate_feature_language(name: str, value) -> str:
"""Validate a feature that is expected to hold language annotations.
Language columns (``language_persistent`` / ``language_events``) are
populated after recording by the annotation pipeline, not at record time.
Any value supplied here is dropped before the frame is written, so a
non-empty value almost certainly signals a mistake. We warn rather than
fail to keep recording resilient.
Args:
name (str): The name of the feature.
value: The value to validate.
Returns:
str: Always an empty string — language values are non-fatal.
"""
if value is not None:
logging.warning(
f"The feature '{name}' is a 'language' column populated by the annotation pipeline, "
f"not at record time. The provided value will be dropped."
)
return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
"""Validate the episode buffer before it's written to disk.

View File

@@ -31,10 +31,10 @@ from torchvision import transforms
from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
from .language import LANGUAGE_COLUMNS
from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_EPISODES_PATH,
DEFAULT_SUBTASKS_PATH,
DEFAULT_TASKS_PATH,
EPISODES_DIR,
INFO_PATH,
@@ -186,6 +186,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
@@ -257,13 +265,11 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
if key in LANGUAGE_COLUMNS:
continue
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None or isinstance(first_item, dict):
elif first_item is None:
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
@@ -298,9 +304,8 @@ def item_to_torch(item: dict) -> dict:
Returns:
dict: Dictionary with all tensor-like items converted to torch.Tensor.
"""
skip_keys = {"task", *LANGUAGE_COLUMNS}
for key, val in item.items():
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item

View File

@@ -1,242 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Literal
import datasets
import pyarrow as pa
LANGUAGE_PERSISTENT = "language_persistent"
LANGUAGE_EVENTS = "language_events"
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
CORE_STYLES = {
"subtask",
"plan",
"memory",
"motion",
"interjection",
"vqa",
"trace",
"task_aug",
}
# Project-local styles can be registered at import time by appending to
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
# validation. Empty by default — populate from a downstream module that
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
# the new style's column.
EXTENDED_STYLES: set[str] = set()
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
# feature key. Rows of every other style MUST have ``camera=None``. ``motion``
# is intentionally NOT in this set: motion primitives are described in
# robot-frame (joint / Cartesian) terms, not pixel space, so they are
# camera-agnostic. ``trace`` is the pixel-trajectory event style and IS
# view-dependent. The ``camera`` field nevertheless lives on
# ``PERSISTENT_ROW_FIELDS`` too so the schema, validator, and resolver
# behave symmetrically across the two columns; persistent rows simply
# always have ``camera=None`` in practice today.
VIEW_DEPENDENT_STYLES = {"vqa", "trace"}
LanguageColumn = Literal["language_persistent", "language_events"]
def _json_arrow_type() -> pa.DataType:
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
return pa.json_() if hasattr(pa, "json_") else pa.string()
def _json_feature() -> object:
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
def language_persistent_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single persistent language row.
Persistent rows carry their own ``timestamp`` because they represent a state
that became active at a specific moment and remains active until superseded.
``timestamp`` is ``float32`` to match the timestamp dtype LeRobotDataset
uses for frame data.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("timestamp", pa.float32(), nullable=False),
pa.field("camera", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_event_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single event language row.
Event rows have no ``timestamp`` field: each event is stored on the dataset
row whose frame timestamp is the event's firing time.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("camera", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_persistent_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_persistent`` column."""
return pa.list_(language_persistent_row_arrow_type())
def language_events_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_events`` column."""
return pa.list_(language_event_row_arrow_type())
def language_persistent_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"timestamp": datasets.Value("float32"),
"camera": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_event_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for an event language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"camera": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_persistent_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
return datasets.List(language_persistent_row_feature())
def language_events_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
return datasets.List(language_event_row_feature())
def language_feature_info() -> dict[str, dict]:
"""Return the ``info["features"]`` entries for both language columns."""
return {
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
}
def is_language_column(key: str) -> bool:
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
return key in LANGUAGE_COLUMNS
def is_view_dependent_style(style: str | None) -> bool:
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
return style in VIEW_DEPENDENT_STYLES
def validate_camera_field(style: str | None, camera: str | None) -> None:
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
a non-view-dependent style carries one. Pipeline writers and the validator
should call this on every emitted row.
"""
if is_view_dependent_style(style):
if not camera:
raise ValueError(
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
f"field referencing an 'observation.images.*' feature key."
)
elif camera is not None:
raise ValueError(f"Rows of style {style!r} must have camera=None; got camera={camera!r}.")
# --- Tool registry --------------------------------------------------------
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
# of OpenAI-style function schemas. The runtime / training stack reads them
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
# fallback when the dataset doesn't declare any). Implementations live
# under :mod:`lerobot.tools` (one file per tool); see
# ``docs/source/tools.mdx`` for the authoring guide.
SAY_TOOL_SCHEMA: dict = {
"type": "function",
"function": {
"name": "say",
"description": "Speak a short utterance to the user via the TTS executor.",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The verbatim text to speak.",
}
},
"required": ["text"],
},
},
}
"""Canonical schema for the ``say`` tool emitted by the steerable
annotation pipeline (PR 2 Module 2). Single source of truth — PR 2's
writer, PR 3's runtime tool registry, and the dataset visualizer all
import this constant rather than duplicating the dict."""
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
chat-template consumers (``apply_chat_template(messages, tools=...)``)
keep working out of the box."""
def column_for_style(style: str | None) -> LanguageColumn:
"""Map a language style to the column where rows of that style are stored.
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
to :data:`LANGUAGE_EVENTS`.
"""
if style is None:
return LANGUAGE_EVENTS
if style in PERSISTENT_STYLES:
return LANGUAGE_PERSISTENT
if style in EVENT_ONLY_STYLES:
return LANGUAGE_EVENTS
raise ValueError(f"Unknown language style: {style!r}")

View File

@@ -1,545 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import hashlib
import re
from collections.abc import Sequence
from typing import Any
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
from lerobot.utils.utils import unwrap_scalar
from .language import LANGUAGE_PERSISTENT, column_for_style
LanguageRow = dict[str, Any]
RenderedMessages = dict[str, list[Any]]
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
def active_at(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row of ``style`` that is active at time ``t``.
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
``camera`` selector. Only valid for persistent styles.
"""
_validate_persistent_resolver("active_at", style)
matches = [
row
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
if _timestamp(row) <= t
]
if not matches:
return None
latest_ts = max(_timestamp(row) for row in matches)
return _select_one(
[row for row in matches if _timestamp(row) == latest_ts],
style=style,
role=role,
tool_name=tool_name,
camera=camera,
)
EMITTED_AT_TOLERANCE_S = 0.1
"""Half-window for matching persistent rows to a frame timestamp in
``emitted_at``. Persistent timestamps come from parquet (float32) and ``t``
is also a float32 from parquet, so in the ideal hot path an exact match
would suffice — but any caller that derives ``t`` arithmetically (e.g.
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
common arithmetic drift without admitting frames that are visibly far
apart at typical control rates (30100 Hz). This does mean two persistent
rows of the same selector emitted within 0.1 s of each other cannot be
told apart by ``emitted_at`` — acceptable because persistent annotations
(subtask / plan / memory transitions) change on a human-action timescale,
not at the camera frame rate."""
def emitted_at(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the row of ``style`` emitted at exactly time ``t``.
For persistent styles, this matches persistent rows whose own ``timestamp``
is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why
we use a tolerance instead of bit-equality). For event styles, the
``events`` list is assumed to come from the dataset row at frame ``t``
(event rows carry no timestamp of their own), so all matching event rows
are considered emitted at ``t``. ``camera`` filters by the row's
``camera`` field — required to disambiguate when multiple view-dependent
rows share ``(t, role)`` across cameras.
"""
if column_for_style(style) == LANGUAGE_PERSISTENT:
matches = [
row
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S
]
else:
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera)
def nth_prev(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that was active ``offset`` steps before ``t``.
Walks back through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
one ``offset`` positions before the row active at ``t``. Only valid for
persistent styles.
"""
return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera)
def nth_next(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
Walks forward through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
one ``offset`` positions after the row active at ``t``. Only valid for
persistent styles.
"""
return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera)
def render_sample(
*,
recipe: TrainingRecipe,
persistent: Sequence[LanguageRow] | None,
events: Sequence[LanguageRow] | None,
t: float,
sample_idx: int,
task: str | None = None,
dataset_ctx: Any | None = None,
) -> RenderedMessages | None:
"""Render the chat-style messages for a single dataset sample.
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
at frame timestamp ``t``, then expands the recipe's message templates.
Returns ``None`` if the resolved sample contains no target message.
"""
persistent_rows = _normalize_rows(persistent or [])
event_rows = _normalize_rows(events or [])
selected_recipe = _select_recipe(recipe, sample_idx)
bindings = _resolve_bindings(
selected_recipe,
persistent=persistent_rows,
events=event_rows,
t=t,
sample_idx=sample_idx,
task=task,
dataset_ctx=dataset_ctx,
)
return _render_message_recipe(selected_recipe, bindings)
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
if recipe.blend is None:
return recipe
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
if total_weight <= 0:
raise ValueError("Blend weights must sum to a positive value.")
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
cumulative = 0.0
last_component: TrainingRecipe | None = None
for component in recipe.blend.values():
last_component = component
cumulative += component.weight or 0.0
if draw < cumulative:
return component
assert last_component is not None
return last_component
def _resolve_bindings(
recipe: TrainingRecipe,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
sample_idx: int,
task: str | None,
dataset_ctx: Any | None,
) -> dict[str, LanguageRow | str | None]:
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
bindings: dict[str, LanguageRow | str | None] = {
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
}
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
for name, spec in specs.items():
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
return bindings
def _resolve_task(
task: str | None,
dataset_ctx: Any | None,
*,
persistent: Sequence[LanguageRow] = (),
sample_idx: int = 0,
) -> str | None:
"""Return the task string for ``sample_idx``.
Resolution order:
1. Explicit ``task`` override (caller-supplied) wins.
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
deterministically pick one by ``sample_idx`` so each frame of an
episode rotates through the available rephrasings across an epoch.
This realizes Xiao 2022 / CAST-style task-prompt diversity without
changing ``meta/tasks.parquet`` and without forcing recipes to opt
in: ``${task}`` automatically picks a rephrasing when one exists,
and falls back to the canonical task otherwise. Recipes that want
the literal canonical task can override the binding.
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
backed by ``meta/tasks.parquet``).
"""
if task is not None:
return task
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
if aug_rows:
# Deterministic, blake2b-based pick keyed on sample_idx so the
# rotation is reproducible across runs (Python's built-in ``hash``
# is process-randomized).
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
idx = int.from_bytes(digest, "big") % len(aug_rows)
chosen = aug_rows[idx].get("content")
if chosen:
return str(chosen)
if dataset_ctx is None:
return None
if isinstance(dataset_ctx, dict):
return dataset_ctx.get("task")
return getattr(dataset_ctx, "task", None)
def _resolve_spec(
spec: str,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
) -> LanguageRow | None:
"""Parse a single binding's resolver expression and dispatch to its function."""
match = _RESOLVER_RE.match(spec.strip())
if match is None:
raise ValueError(f"Invalid resolver expression: {spec!r}")
name = match.group("name")
kwargs = _parse_resolver_args(match.group("args"))
kwargs.pop("t_arg", None)
if name == "emitted_at":
return emitted_at(t, persistent=persistent, events=events, **kwargs)
if name == "active_at":
return active_at(t, persistent=persistent, **kwargs)
if name == "nth_prev":
return nth_prev(t, persistent=persistent, **kwargs)
if name == "nth_next":
return nth_next(t, persistent=persistent, **kwargs)
raise ValueError(f"Unknown language resolver: {name!r}")
def _parse_resolver_args(args: str) -> dict[str, Any]:
"""Parse a comma-separated resolver argument list into a kwargs dict."""
kwargs: dict[str, Any] = {}
if not args.strip():
return kwargs
parts = [part.strip() for part in args.split(",") if part.strip()]
for part in parts:
if part == "t":
kwargs["t_arg"] = True
continue
if "=" not in part:
raise ValueError(f"Invalid resolver argument: {part!r}")
key, value = (item.strip() for item in part.split("=", 1))
if key == "offset":
kwargs[key] = int(value)
else:
kwargs[key] = value.strip("\"'")
return kwargs
def _render_message_recipe(
recipe: TrainingRecipe,
bindings: dict[str, LanguageRow | str | None],
) -> RenderedMessages | None:
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
assert recipe.messages is not None
messages: list[dict[str, Any]] = []
streams: list[str | None] = []
target_indices: list[int] = []
for turn in recipe.messages:
if turn.if_present is not None and bindings.get(turn.if_present) is None:
continue
message = {"role": turn.role}
if turn.content is not None:
message["content"] = _render_content(turn.content, bindings)
if turn.tool_calls_from is not None:
row = bindings.get(turn.tool_calls_from)
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
if tool_calls:
message["tool_calls"] = copy.deepcopy(tool_calls)
message_idx = len(messages)
messages.append(message)
streams.append(turn.stream)
if turn.target:
target_indices.append(message_idx)
if not target_indices:
return None
rendered = {
"messages": messages,
"message_streams": streams,
"target_message_indices": target_indices,
}
_validate_rendered(rendered)
return rendered
def _render_content(
content: str | list[dict[str, Any]],
bindings: dict[str, LanguageRow | str | None],
) -> str | list[dict[str, Any]]:
"""Substitute bindings into a string or each string field of multimodal blocks."""
if isinstance(content, str):
return _substitute(content, bindings)
rendered_blocks = []
for block in content:
rendered_block = copy.deepcopy(block)
for key, value in rendered_block.items():
if isinstance(value, str):
rendered_block[key] = _substitute(value, bindings)
rendered_blocks.append(rendered_block)
return rendered_blocks
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
def replace(match: re.Match[str]) -> str:
"""Resolve a single ``${name}`` match to its bound string value."""
name = match.group(1)
if name not in bindings:
raise ValueError(f"Unknown template binding: {name!r}")
value = bindings[name]
if value is None:
return ""
if isinstance(value, dict):
content = value.get("content")
return "" if content is None else str(content)
return str(value)
return PLACEHOLDER_RE.sub(replace, template)
def _validate_rendered(rendered: RenderedMessages) -> None:
"""Sanity-check the rendered output for stream/target alignment."""
messages = rendered["messages"]
streams = rendered["message_streams"]
target_indices = rendered["target_message_indices"]
if len(streams) != len(messages):
raise ValueError("message_streams must be aligned with messages.")
if not target_indices:
raise ValueError("Rendered samples must contain at least one target message.")
for idx in target_indices:
if idx < 0 or idx >= len(messages):
raise ValueError(f"Target message index {idx} is out of bounds.")
# ``stream`` is enforced non-None at MessageTurn construction time
# (see ``MessageTurn.__post_init__``), so a missing stream here would
# mean the dataclass invariant was bypassed; no need to re-check.
def _nth_relative(
name: str,
t: float,
persistent: Sequence[LanguageRow],
style: str | None,
offset: int,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> LanguageRow | None:
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
_validate_persistent_resolver(name, style)
if abs(offset) < 1:
raise ValueError(f"{name} offset must be non-zero.")
rows = sorted(
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
key=_row_sort_key,
)
if not rows:
return None
anchor_idx = None
for idx, row in enumerate(rows):
if _timestamp(row) <= t:
anchor_idx = idx
else:
break
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
return None
return rows[target_idx]
def _validate_persistent_resolver(name: str, style: str | None) -> None:
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
if style is None:
raise ValueError(f"{name} requires a persistent style.")
if column_for_style(style) != LANGUAGE_PERSISTENT:
raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
def _matching_rows(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> list[LanguageRow]:
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
return [
row
for row in rows
if (style is None or row.get("style") == style)
and (role is None or row.get("role") == role)
and (tool_name is None or _row_has_tool_name(row, tool_name))
and (camera is None or row.get("camera") == camera)
]
def _select_one(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> LanguageRow | None:
"""Return the single matching row, or raise if the resolver is ambiguous.
Multiple matches always raise — even when the caller already passed
some selectors — because remaining ambiguity means the data has
several rows that look identical to the resolver and the caller
needs to pin down a specific one (e.g. add ``camera=...`` for VQA
rows shared across cameras).
"""
if not rows:
return None
if len(rows) > 1:
raise ValueError(
f"Ambiguous resolver for style={style!r} role={role!r} "
f"tool_name={tool_name!r} camera={camera!r}: {len(rows)} matching rows. "
f"Add a selector that distinguishes them."
)
return rows[0]
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
"""Stable sort key for both persistent and event rows.
Event rows lack ``timestamp`` (it is implicit in the frame), so default
to ``0.0`` — within a single frame all event rows share the same sort
bucket and are tiebroken by ``(style, role)``.
"""
timestamp = row.get("timestamp")
ts = float(unwrap_scalar(timestamp)) if timestamp is not None else 0.0
return (ts, row.get("style") or "", row.get("role") or "")
def _timestamp(row: LanguageRow) -> float:
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
return float(unwrap_scalar(row["timestamp"]))
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
for tool_call in row.get("tool_calls") or []:
if isinstance(tool_call, str):
continue
function = tool_call.get("function") if isinstance(tool_call, dict) else None
if isinstance(function, dict) and function.get("name") == tool_name:
return True
return False
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
normalized = []
for row in rows:
if row is None:
continue
if hasattr(row, "as_py"):
row = row.as_py()
if not isinstance(row, dict):
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
normalized.append(dict(row))
return normalized

View File

@@ -24,7 +24,6 @@ import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.configs import VideoEncoderConfig
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
@@ -37,7 +36,8 @@ from .utils import (
)
from .video_utils import (
StreamingVideoEncoder,
get_safe_default_video_backend,
get_safe_default_codec,
resolve_vcodec,
)
logger = logging.getLogger(__name__)
@@ -59,10 +59,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
return_uint8: bool = False,
batch_encoding_size: int = 1,
camera_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
vcodec: str = "libsvtav1",
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
):
"""
2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -183,15 +183,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults`
is used by the writer.
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
codec decide.
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
streaming encoding. Defaults to 30 (~1s at 30fps).
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
libsvtav1 and 'threads' for h264/hevc.
Note:
Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to
@@ -206,9 +207,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.delta_timestamps = delta_timestamps
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
self._video_backend = video_backend if video_backend else get_safe_default_codec()
self._return_uint8 = return_uint8
self._batch_encoding_size = batch_encoding_size
self._vcodec = resolve_vcodec(vcodec)
self._encoder_threads = encoder_threads
if self._requested_root is not None:
@@ -271,15 +273,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = None
if streaming_encoding and len(self.meta.video_keys) > 0:
streaming_enc = self._build_streaming_encoder(
self.meta.fps,
camera_encoder,
encoder_queue_maxsize,
encoder_threads,
self.meta.fps, self._vcodec, encoder_queue_maxsize, encoder_threads
)
self.writer = DatasetWriter(
meta=self.meta,
root=self.root,
camera_encoder=camera_encoder,
vcodec=self._vcodec,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
@@ -321,13 +320,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
@staticmethod
def _build_streaming_encoder(
fps: int,
camera_encoder: VideoEncoderConfig | None,
vcodec: str,
encoder_queue_maxsize: int,
encoder_threads: int | None,
) -> StreamingVideoEncoder:
return StreamingVideoEncoder(
fps=fps,
camera_encoder=camera_encoder,
vcodec=vcodec,
pix_fmt="yuv420p",
g=2,
crf=30,
preset=None,
queue_maxsize=encoder_queue_maxsize,
encoder_threads=encoder_threads,
)
@@ -644,7 +647,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_writer_threads: int = 0,
video_backend: str | None = None,
batch_encoding_size: int = 1,
camera_encoder: VideoEncoderConfig | None = None,
vcodec: str = "libsvtav1",
metadata_buffer_size: int = 10,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
@@ -675,20 +678,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: Video decoding backend (used when reading back).
batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos. ``1`` means encode immediately.
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
vcodec: Video codec for encoding. Options include ``'libsvtav1'``,
``'h264'``, ``'hevc'``, ``'auto'``.
metadata_buffer_size: Number of episode metadata records to buffer
before flushing to parquet.
streaming_encoding: If ``True``, encode video frames in real-time
during capture instead of writing images first.
encoder_queue_maxsize: Max buffered frames per camera when using
streaming encoding.
encoder_threads: Threads per encoder instance. ``None`` for auto.
Returns:
A new :class:`LeRobotDataset` in write mode.
"""
vcodec = resolve_vcodec(vcodec)
obj = cls.__new__(cls)
obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
@@ -709,23 +712,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.episodes = None
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj._return_uint8 = False
obj._batch_encoding_size = batch_encoding_size
obj._vcodec = vcodec
obj._encoder_threads = encoder_threads
# Reader is lazily created on first access (write-only mode)
obj.reader = None
# Create writer
streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder(
fps, camera_encoder, encoder_queue_maxsize, encoder_threads
)
streaming_enc = cls._build_streaming_encoder(fps, vcodec, encoder_queue_maxsize, encoder_threads)
obj.writer = DatasetWriter(
meta=obj.meta,
root=obj.root,
camera_encoder=camera_encoder,
vcodec=vcodec,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
@@ -748,12 +751,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
force_cache_sync: bool = False,
video_backend: str | None = None,
batch_encoding_size: int = 1,
camera_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
vcodec: str = "libsvtav1",
image_writer_processes: int = 0,
image_writer_threads: int = 0,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
) -> "LeRobotDataset":
"""Resume recording on an existing dataset.
@@ -776,15 +779,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: Video decoding backend for reading back data.
batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos.
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
vcodec: Video codec for encoding.
image_writer_processes: Subprocesses for async image writing.
image_writer_threads: Threads for async image writing.
streaming_encoding: If ``True``, encode video in real-time during
capture.
encoder_queue_maxsize: Max buffered frames per camera for streaming.
encoder_threads: Threads per encoder instance. ``None`` for auto.
Returns:
A :class:`LeRobotDataset` in write mode, ready to append episodes.
@@ -795,6 +796,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
"the shared cache. Please provide a local directory path."
)
vcodec = resolve_vcodec(vcodec)
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj._requested_root = Path(root)
@@ -803,9 +805,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.episodes = None
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
obj._video_backend = video_backend if video_backend else get_safe_default_codec()
obj._return_uint8 = False
obj._batch_encoding_size = batch_encoding_size
obj._vcodec = vcodec
obj._encoder_threads = encoder_threads
if obj._requested_root is not None:
obj._requested_root.mkdir(exist_ok=True, parents=True)
@@ -814,22 +818,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.meta = LeRobotDatasetMetadata(
obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
)
obj._encoder_threads = encoder_threads
obj.root = obj.meta.root
# Reader is lazily created on first access (write-only mode)
obj.reader = None
# Create writer for appending
streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder(
obj.meta.fps, camera_encoder, encoder_queue_maxsize, encoder_threads
obj.meta.fps, vcodec, encoder_queue_maxsize, encoder_threads
)
obj.writer = DatasetWriter(
meta=obj.meta,
root=obj.root,
camera_encoder=camera_encoder,
vcodec=vcodec,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,

View File

@@ -1,174 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyAV-based compatibility checks for :class:`VideoEncoderConfig`.
Centralises all :mod:`av` introspection of the bundled FFmpeg build.
Checks degrade to a no-op when the target codec isn't available locally.
"""
import functools
import logging
from typing import Any
import av
logger = logging.getLogger(__name__)
FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
@functools.cache
def get_codec(vcodec: str) -> av.codec.Codec | None:
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
try:
return av.codec.Codec(vcodec, "w")
except Exception:
return None
@functools.cache
def _get_codec_options_by_name(vcodec: str) -> dict[str, av.option.Option]:
"""Private-option name → PyAV ``Option`` for *vcodec* (empty if unavailable)."""
codec = get_codec(vcodec)
if codec is None:
return {}
return {opt.name: opt for opt in codec.descriptor.options}
@functools.cache
def _get_codec_video_formats(vcodec: str) -> tuple[str, ...]:
"""Pixel formats accepted by *vcodec* in PyAV's preferred order (empty if unknown)."""
codec = get_codec(vcodec)
if codec is None:
return ()
return tuple(fmt.name for fmt in (codec.video_formats or []))
def detect_available_encoders_pyav(encoders: list[str] | str) -> list[str]:
"""Return the subset of *encoders* available as video encoders in the local FFmpeg build.
Each name is probed directly via :func:`get_codec`; input order is preserved.
"""
if isinstance(encoders, str):
encoders = [encoders]
available: list[str] = []
for name in encoders:
codec = get_codec(name)
if codec is not None and codec.type == "video":
available.append(name)
else:
logger.debug("encoder '%s' not available as video encoder", name)
return available
def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Option) -> None:
"""Range-check numeric *value* and choice-check string *value* against *opt*."""
type_name = opt.type.name
if type_name in FFMPEG_NUMERIC_OPTION_TYPES:
if isinstance(value, bool):
raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
)
elif isinstance(value, str):
try:
num_val = float(value)
except ValueError as e:
raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
) from e
elif isinstance(value, (float, int)):
num_val = value
else:
raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
)
# Check integer type compatibility
if type_name in FFMPEG_INTEGER_OPTION_TYPES and not num_val.is_integer():
raise ValueError(
f"{label}={num_val!r} must be an integer for codec {vcodec!r} "
f"(FFmpeg option {opt.name!r} is {type_name}); float values are not allowed."
)
# Check numeric range compatibility
lo, hi = float(opt.min), float(opt.max)
if lo < hi and not (lo <= num_val <= hi):
raise ValueError(
f"{label}={num_val} is out of range for codec {vcodec!r}; must be in [{lo}, {hi}]"
)
elif type_name == "STRING":
if isinstance(value, bool):
raise ValueError(f"{label}={value!r} is not a valid string value for codec {vcodec!r}.")
if isinstance(value, str):
str_val = value
elif isinstance(value, (int, float)):
str_val = str(value)
else:
raise ValueError(f"{label}={value!r} has unsupported type for STRING option on codec {vcodec!r}")
# Check string choice compatibility
choices = [c.name for c in (opt.choices or [])]
if choices and str_val not in choices:
raise ValueError(
f"{label}={str_val!r} is not a supported choice for codec "
f"{vcodec!r}; valid choices: {choices}"
)
else:
return
def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
formats = _get_codec_video_formats(vcodec)
if formats and pix_fmt not in formats:
raise ValueError(
f"pix_fmt={pix_fmt!r} is not supported by codec {vcodec!r}; "
f"supported pixel formats: {list(formats)}"
)
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
supported_options = _get_codec_options_by_name(vcodec)
for key, value in codec_options.items():
# GOP size is not a codec-specific option, it has to be validated separately.
if key == "g":
if isinstance(value, bool) or not isinstance(value, int) or value < 1:
raise ValueError(f"g={value!r} must be a positive integer for codec {vcodec!r}")
continue
if key not in supported_options:
continue
_check_option_value(vcodec, key, value, supported_options[key])
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
"""Verify *config* is compatible with the bundled FFmpeg build.
Checks pixel format, abstract tuning-field compatibility, and each merged
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
against PyAV (including numeric ``extra_options`` present in that dict).
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
Raises:
ValueError: on the first incompatibility encountered.
"""
options = _get_codec_options_by_name(vcodec)
if not options:
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
_check_pixel_format(vcodec, pix_fmt)
_check_codec_options(vcodec, codec_options)

View File

@@ -88,6 +88,7 @@ VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
@@ -129,9 +130,6 @@ class DatasetInfo:
# Optional metadata
robot_type: str | None = None
splits: dict[str, str] = field(default_factory=dict)
# OpenAI-style tool schemas declared by the dataset. ``None`` means the
# dataset doesn't declare any — readers fall back to ``DEFAULT_TOOLS``.
tools: list[dict] | None = None
def __post_init__(self) -> None:
# Coerce feature shapes from list to tuple — JSON deserialisation
@@ -153,15 +151,11 @@ class DatasetInfo:
"""Return a JSON-serialisable dict.
Converts tuple shapes back to lists so ``json.dump`` can handle them.
Drops ``tools`` when unset so existing datasets keep a clean
``info.json``.
"""
d = dataclasses.asdict(self)
for ft in d["features"].values():
if isinstance(ft.get("shape"), tuple):
ft["shape"] = list(ft["shape"])
if d.get("tools") is None:
d.pop("tools", None)
return d
@classmethod

View File

@@ -17,14 +17,12 @@ import contextlib
import glob
import importlib
import logging
import os
import queue
import shutil
import tempfile
import threading
import warnings
from collections import OrderedDict
from dataclasses import asdict, dataclass, field
from dataclasses import dataclass, field
from fractions import Fraction
from pathlib import Path
from threading import Lock
@@ -38,14 +36,86 @@ import torch
from datasets.features.features import register_feature
from PIL import Image
from lerobot.configs import (
VideoEncoderConfig,
camera_encoder_defaults,
)
from lerobot.utils.import_utils import get_safe_default_video_backend
from lerobot.utils.import_utils import get_safe_default_codec
logger = logging.getLogger(__name__)
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
# Determines the order of preference for auto-selection when vcodec="auto" is used.
HW_ENCODERS = [
"h264_videotoolbox", # macOS
"hevc_videotoolbox", # macOS
"h264_nvenc", # NVIDIA GPU
"hevc_nvenc", # NVIDIA GPU
"h264_vaapi", # Linux Intel/AMD
"h264_qsv", # Intel Quick Sync
]
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
def _get_codec_options(
vcodec: str,
g: int | None = 2,
crf: int | None = 30,
preset: int | None = None,
) -> dict:
"""Build codec-specific options dict for video encoding."""
options = {}
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
options["g"] = str(g)
# Quality control (codec-specific parameter names)
if crf is not None:
if vcodec in ("h264", "hevc", "libsvtav1"):
options["crf"] = str(crf)
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
quality = max(1, min(100, int(100 - crf * 2)))
options["q:v"] = str(quality)
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
options["rc"] = "constqp"
options["qp"] = str(crf)
elif vcodec in ("h264_vaapi",):
options["qp"] = str(crf)
elif vcodec in ("h264_qsv",):
options["global_quality"] = str(crf)
# Preset (only for libsvtav1)
if vcodec == "libsvtav1":
options["preset"] = str(preset) if preset is not None else "12"
return options
def detect_available_hw_encoders() -> list[str]:
"""Probe PyAV/FFmpeg for available hardware video encoders."""
available = []
for codec_name in HW_ENCODERS:
try:
av.codec.Codec(codec_name, "w")
available.append(codec_name)
except Exception: # nosec B110
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
return available
def resolve_vcodec(vcodec: str) -> str:
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if vcodec != "auto":
logger.info(f"Using video codec: {vcodec}")
return vcodec
available = detect_available_hw_encoders()
for encoder in HW_ENCODERS:
if encoder in available:
logger.info(f"Auto-selected video codec: {encoder}")
return encoder
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
return "libsvtav1"
def decode_video_frames(
video_path: Path | str,
@@ -73,7 +143,7 @@ def decode_video_frames(
Currently supports torchcodec on cpu and pyav.
"""
if backend is None:
backend = get_safe_default_video_backend()
backend = get_safe_default_codec()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend == "pyav":
@@ -193,70 +263,15 @@ def decode_video_frames_pyav(
return closest_frames
DEFAULT_DECODER_CACHE_SIZE = 100
"""Default LRU capacity for :class:`VideoDecoderCache`.
Sized to comfortably hold a small rolling window of episodes worth of decoders
(typical recipes: 2-4 cameras per episode × tens of episodes in flight) while
bounding host RAM. Each cached entry retains a torchcodec ``VideoDecoder`` plus
an open ``fsspec`` file handle — on the order of a few MB per entry. Override
via the ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var or by passing ``max_size``
to the constructor (``None`` restores the legacy unbounded behaviour).
"""
def _default_max_cache_size() -> int | None:
raw = os.environ.get("LEROBOT_VIDEO_DECODER_CACHE_SIZE")
if raw is None:
return DEFAULT_DECODER_CACHE_SIZE
raw = raw.strip().lower()
if raw in ("", "none", "unbounded", "-1"):
return None
try:
value = int(raw)
except ValueError as e:
raise ValueError(
f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be an integer, 'none', or '-1'; got {raw!r}"
) from e
if value <= 0:
raise ValueError(f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be positive; got {value}")
return value
class VideoDecoderCache:
"""Thread-safe LRU cache for torchcodec ``VideoDecoder`` instances.
"""Thread-safe cache for video decoders to avoid expensive re-initialization."""
Cached entries hold a ``VideoDecoder`` plus the open ``fsspec`` file handle
backing it. When the cache is full and a new path is requested, the
least-recently-used entry is evicted and its file handle is closed. This
bounds host-RAM growth when iterating over datasets with many distinct
video files (otherwise each ``DataLoader`` worker pins every decoder it has
ever opened until the process exits).
Args:
max_size: Maximum number of decoders to retain. ``None`` disables
eviction and restores legacy unbounded behaviour. Defaults to the
value of ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` if set, otherwise
:data:`DEFAULT_DECODER_CACHE_SIZE`.
"""
_SENTINEL: ClassVar[object] = object()
def __init__(self, max_size: int | None | object = _SENTINEL):
if max_size is VideoDecoderCache._SENTINEL:
max_size = _default_max_cache_size()
if max_size is not None and max_size <= 0:
raise ValueError(f"max_size must be positive or None; got {max_size}")
self.max_size: int | None = max_size # type: ignore[assignment]
self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
def __init__(self):
self._cache: dict[str, tuple[Any, Any]] = {}
self._lock = Lock()
def __contains__(self, video_path: object) -> bool:
with self._lock:
return str(video_path) in self._cache
def get_decoder(self, video_path: str):
"""Get a cached decoder or create a new one, evicting LRU if at capacity."""
"""Get a cached decoder or create a new one."""
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
@@ -268,36 +283,22 @@ class VideoDecoderCache:
video_path = str(video_path)
with self._lock:
entry = self._cache.get(video_path)
if entry is not None:
self._cache.move_to_end(video_path)
return entry[0]
if video_path not in self._cache:
file_handle = fsspec.open(video_path).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
file_handle.close()
raise
self._cache[video_path] = (decoder, file_handle)
file_handle = fsspec.open(video_path).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
file_handle.close()
raise
self._cache[video_path] = (decoder, file_handle)
# Evict LRU entries until we are back under the cap. We close
# evicted file handles immediately; the associated ``VideoDecoder``
# is released to the GC when its last reference goes away.
if self.max_size is not None:
while len(self._cache) > self.max_size:
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
with contextlib.suppress(Exception):
evicted_handle.close()
return decoder
return self._cache[video_path][0]
def clear(self):
"""Clear the cache and close all file handles."""
"""Clear the cache and close file handles."""
with self._lock:
for _, file_handle in self._cache.values():
with contextlib.suppress(Exception):
file_handle.close()
file_handle.close()
self._cache.clear()
def size(self) -> int:
@@ -406,17 +407,18 @@ def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,
fps: int,
camera_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
*,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int | None = 2,
crf: int | None = 30,
fast_decode: int = 0,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
preset: int | None = None,
encoder_threads: int | None = None,
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
if camera_encoder is None:
camera_encoder = camera_encoder_defaults()
vcodec = camera_encoder.vcodec
pix_fmt = camera_encoder.pix_fmt
vcodec = resolve_vcodec(vcodec)
video_path = Path(video_path)
imgs_dir = Path(imgs_dir)
@@ -427,18 +429,42 @@ def encode_video_frames(
video_path.parent.mkdir(parents=True, exist_ok=True)
# Encoders/pixel formats incompatibility check
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
logger.warning(
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
)
pix_fmt = "yuv420p"
# Get input frames
template = "frame-" + ("[0-9]" * 6) + ".png"
input_list = sorted(
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
)
# Define video output frame size (assuming all input frames are the same size)
if len(input_list) == 0:
raise FileNotFoundError(f"No images found in {imgs_dir}.")
with Image.open(input_list[0]) as dummy_image:
width, height = dummy_image.size
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
# Define video codec options
video_options = _get_codec_options(vcodec, g, crf, preset)
if fast_decode:
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
video_options[key] = value
if encoder_threads is not None:
if vcodec == "libsvtav1":
lp_param = f"lp={encoder_threads}"
if "svtav1-params" in video_options:
video_options["svtav1-params"] += f":{lp_param}"
else:
video_options["svtav1-params"] = lp_param
else:
video_options["threads"] = str(encoder_threads)
# Set logging level
if log_level is not None:
@@ -474,97 +500,8 @@ def encode_video_frames(
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
def reencode_video(
input_video_path: Path | str,
output_video_path: Path | str,
camera_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
) -> None:
"""Re-encode a video file using the given encoder configuration.
Args:
input_video_path: Existing video file to read.
output_video_path: Path for the re-encoded file.
camera_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
"""
camera_encoder = camera_encoder or camera_encoder_defaults()
output_video_path = Path(output_video_path)
if output_video_path.exists() and not overwrite:
logger.warning(f"Video file already exists: {output_video_path}. Skipping re-encode.")
return
output_video_path.parent.mkdir(parents=True, exist_ok=True)
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
vcodec = camera_encoder.vcodec
pix_fmt = camera_encoder.pix_fmt
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
tmp_output_video_path = tmp_named_file.name
if log_level is not None:
logging.getLogger("libav").setLevel(log_level)
try:
with av.open(input_video_path, mode="r") as src:
try:
in_stream = src.streams.video[0]
except IndexError as e:
raise ValueError(f"No video stream in {input_video_path}") from e
fps = (
in_stream.base_rate
) # We allow fractional fps though LeRobotDataset only supports integer fps
width = int(in_stream.width)
height = int(in_stream.height)
with av.open(
tmp_output_video_path,
mode="w",
options={
"movflags": "faststart"
}, # faststart is to move the metadata to the beginning of the file to speed up loading
) as dst:
out_stream = dst.add_stream(vcodec, fps, options=video_options)
out_stream.pix_fmt = pix_fmt
out_stream.width = width
out_stream.height = height
for frame in src.decode(in_stream):
frame = frame.reformat(width=width, height=height, format=pix_fmt)
packet = out_stream.encode(frame)
if packet:
dst.mux(packet)
packet = out_stream.encode()
if packet:
dst.mux(packet)
shutil.move(tmp_output_video_path, output_video_path)
except Exception:
Path(tmp_output_video_path).unlink(missing_ok=True)
raise
finally:
if log_level is not None:
av.logging.restore_default_callback()
if not output_video_path.exists():
raise OSError(f"Video re-encoding did not work. File not found: {output_video_path}.")
def concatenate_video_files(
input_video_paths: list[Path | str],
output_video_path: Path,
overwrite: bool = True,
compatibility_check: bool = False,
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
):
"""
Concatenate multiple video files into a single video file using pyav.
@@ -577,7 +514,6 @@ def concatenate_video_files(
input_video_paths: Ordered list of input video file paths to concatenate.
output_video_path: Path to the output video file.
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
compatibility_check: Whether to check if the input videos are compatible. Default is False.
Note:
- Creates a temporary directory for intermediate files that is cleaned up after use.
@@ -596,22 +532,6 @@ def concatenate_video_files(
if len(input_video_paths) == 0:
raise FileNotFoundError("No input video paths provided.")
# This check may be skipped at recording time as videos are encoded with the same encoder config.
if compatibility_check:
reference_video_info = get_video_info(input_video_paths[0])
for input_path in input_video_paths[1:]:
video_info = get_video_info(input_path)
if (
video_info["video.height"] != reference_video_info["video.height"]
or video_info["video.width"] != reference_video_info["video.width"]
or video_info["video.fps"] != reference_video_info["video.fps"]
or video_info["video.codec"] != reference_video_info["video.codec"]
or video_info["video.pix_fmt"] != reference_video_info["video.pix_fmt"]
):
raise ValueError(
f"Input video {input_path} is not compatible with the reference video {input_video_paths[0]}."
)
# Create a temporary .ffconcat file to list the input video paths
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
tmp_concatenate_file.write("ffconcat version 1.0\n")
@@ -678,20 +598,26 @@ class _CameraEncoderThread(threading.Thread):
fps: int,
vcodec: str,
pix_fmt: str,
codec_options: dict[str, str],
g: int | None,
crf: int | None,
preset: int | None,
frame_queue: queue.Queue,
result_queue: queue.Queue,
stop_event: threading.Event,
encoder_threads: int | None = None,
):
super().__init__(daemon=True)
self.video_path = video_path
self.fps = fps
self.vcodec = vcodec
self.pix_fmt = pix_fmt
self.codec_options = codec_options
self.g = g
self.crf = crf
self.preset = preset
self.frame_queue = frame_queue
self.result_queue = result_queue
self.stop_event = stop_event
self.encoder_threads = encoder_threads
def run(self) -> None:
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
@@ -727,9 +653,19 @@ class _CameraEncoderThread(threading.Thread):
# Open container on first frame (to get width/height)
if container is None:
height, width = frame_data.shape[:2]
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
if self.encoder_threads is not None:
if self.vcodec == "libsvtav1":
lp_param = f"lp={self.encoder_threads}"
if "svtav1-params" in video_options:
video_options["svtav1-params"] += f":{lp_param}"
else:
video_options["svtav1-params"] = lp_param
else:
video_options["threads"] = str(self.encoder_threads)
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
container = av.open(str(self.video_path), "w")
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
output_stream.pix_fmt = self.pix_fmt
output_stream.width = width
output_stream.height = height
@@ -795,24 +731,22 @@ class StreamingVideoEncoder:
def __init__(
self,
fps: int,
camera_encoder: VideoEncoderConfig | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int | None = 2,
crf: int | None = 30,
preset: int | None = None,
queue_maxsize: int = 30,
encoder_threads: int | None = None,
):
"""
Args:
fps: Frames per second for the output videos.
camera_encoder: Video encoder settings applied to all cameras.
When ``None``, :func:`camera_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global setting).
``None`` lets the codec decide.
queue_maxsize: Max frames to buffer per camera before
back-pressure drops frames.
"""
self.fps = fps
self._camera_encoder = camera_encoder or camera_encoder_defaults()
self._encoder_threads = encoder_threads
self.vcodec = resolve_vcodec(vcodec)
self.pix_fmt = pix_fmt
self.g = g
self.crf = crf
self.preset = preset
self.queue_maxsize = queue_maxsize
self.encoder_threads = encoder_threads
self._frame_queues: dict[str, queue.Queue] = {}
self._result_queues: dict[str, queue.Queue] = {}
@@ -843,17 +777,18 @@ class StreamingVideoEncoder:
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
vcodec = self._camera_encoder.vcodec
codec_options = self._camera_encoder.get_codec_options(self._encoder_threads, as_strings=True)
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=self.fps,
vcodec=vcodec,
pix_fmt=self._camera_encoder.pix_fmt,
codec_options=codec_options,
vcodec=self.vcodec,
pix_fmt=self.pix_fmt,
g=self.g,
crf=self.crf,
preset=self.preset,
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
encoder_threads=self.encoder_threads,
)
encoder_thread.start()
@@ -1058,18 +993,8 @@ def get_audio_info(video_path: Path | str) -> dict:
return audio_info
def get_video_info(
video_path: Path | str,
camera_encoder: VideoEncoderConfig | None = None,
) -> dict:
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
Args:
video_path: Path to the encoded video file to probe.
camera_encoder: If provided, record the exact encoder settings used to encode this
video. Stream-derived values take precedence — encoder fields are only written for keys
not already populated from the video file itself.
"""
def get_video_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.WARNING)
# Getting video stream information
@@ -1100,14 +1025,6 @@ def get_video_info(
# Adding audio stream information
video_info.update(**get_audio_info(video_path))
# Add additional encoder configuration if provided
if camera_encoder is not None:
for field_name, field_value in asdict(camera_encoder).items():
# vcodec is already populated from the video stream
if field_name == "vcodec":
continue
video_info.setdefault(f"video.{field_name}", field_value)
return video_info

View File

@@ -18,25 +18,12 @@ from typing import TYPE_CHECKING
import numpy as np
from lerobot.utils.import_utils import require_package
from lerobot.utils.import_utils import _placo_available, require_package
_placo_runtime_error: ImportError | None = None
if TYPE_CHECKING:
if TYPE_CHECKING or _placo_available:
import placo # type: ignore[import-not-found]
else:
try:
import placo # type: ignore[import-not-found]
except ImportError as _placo_import_err:
placo = None
_placo_runtime_error = _placo_import_err
def _raise_if_placo_unusable() -> None:
if placo is None and _placo_runtime_error is not None:
raise ImportError(
f"placo is installed but failed to import: {_placo_runtime_error!s}"
) from _placo_runtime_error
placo = None
class RobotKinematics:
@@ -57,7 +44,6 @@ class RobotKinematics:
joint_names (list[str] | None): List of joint names to use for the kinematics solver
"""
require_package("placo", extra="placo-dep")
_raise_if_placo_unusable()
self.robot = placo.RobotWrapper(urdf_path)
self.solver = placo.KinematicsSolver(self.robot)

View File

@@ -43,7 +43,6 @@ from .tables import (
CAN_CMD_SET_ZERO,
DEFAULT_BAUDRATE,
DEFAULT_TIMEOUT_MS,
HANDSHAKE_TIMEOUT_S,
MODEL_RESOLUTION,
MOTOR_LIMIT_PARAMS,
NORMALIZED_DATA,
@@ -216,16 +215,14 @@ class RobstrideMotorsBus(MotorsBusBase):
self._is_connected = False
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
def _query_status_via_clear_fault(
self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
) -> tuple[bool, can.Message | None]:
def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
motor_name = self._get_motor_name(motor)
motor_id = self._get_motor_id(motor_name)
recv_id = self._get_motor_recv_id(motor_name)
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self._bus().send(msg)
return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)
return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
def _recv_status_via_clear_fault(
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
@@ -283,7 +280,7 @@ class RobstrideMotorsBus(MotorsBusBase):
faulted_motors = []
for motor_name in self.motors:
has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
has_fault, msg = self._query_status_via_clear_fault(motor_name)
if msg is None:
missing_motors.append(motor_name)
elif has_fault:
@@ -508,87 +505,6 @@ class RobstrideMotorsBus(MotorsBusBase):
return responses
def _recv_all_messages_until_quiet(
self,
*,
timeout: float = RUNNING_TIMEOUT,
max_messages: int = 4096,
) -> list[can.Message]:
"""
Receive frames until the bus goes quiet.
Args:
timeout: Poll timeout used for each recv() call. Collection stops
when one recv() times out (quiet gap).
max_messages: Safety cap to prevent unbounded loops.
"""
out: list[can.Message] = []
max_messages = max(1, max_messages)
timeout = max(0.0, timeout)
try:
while len(out) < max_messages:
msg = self._bus().recv(timeout=timeout)
if msg is None:
break
out.append(msg)
except (can.CanError, OSError) as e:
logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
return out
def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
"""
Decode all received feedback frames and update cached motor states.
Returns:
Set of payload recv_ids that were successfully mapped to motors.
"""
processed_recv_ids: set[int] = set()
for msg in messages:
if len(msg.data) < 1:
logger.debug(
f"Dropping short CAN frame on {self.port} "
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
)
continue
recv_id = int(msg.data[0])
motor_name = self._recv_id_to_motor.get(recv_id)
if motor_name is None:
logger.debug(
f"Unmapped CAN frame on {self.port} "
f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
)
continue
self._process_response(motor_name, msg)
processed_recv_ids.add(recv_id)
return processed_recv_ids
def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
"""
Drain pending RX frames from the CAN interface.
This is used by higher-level controllers to drop stale feedback before issuing
a fresh read cycle, so subsequent state reads are based on most recent replies.
It should also be called once when a controller instance is created/connected,
to clear residual frames left on the interface from previous sessions.
"""
drained = 0
poll_timeout_s = max(0.0, poll_timeout_s)
max_messages = max(1, max_messages)
try:
while drained < max_messages:
msg = self._bus().recv(timeout=poll_timeout_s)
if msg is None:
break
drained += 1
except (can.CanError, OSError) as e:
logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
return drained
def _speed_control(
self,
motor: NameOrID,
@@ -728,14 +644,11 @@ class RobstrideMotorsBus(MotorsBusBase):
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self._bus().send(msg)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
# Read every feedback frame until RX goes quiet, then decode all of them.
# This avoids dropping useful frames when responses from different motors interleave.
messages = self._recv_all_messages_until_quiet()
processed_recv_ids = self._process_feedback_messages(messages)
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
for recv_id, motor_name in recv_id_to_motor.items():
if recv_id not in processed_recv_ids:
logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")
if msg := responses.get(recv_id):
self._process_response(motor_name, msg)
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
"""Convert float to unsigned integer for CAN transmission."""
@@ -798,10 +711,7 @@ class RobstrideMotorsBus(MotorsBusBase):
try:
self._decode_motor_state(msg.data)
except Exception as e:
logger.warning(
f"Failed to decode response from {motor} "
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
)
logger.warning(f"Failed to decode response from {motor}: {e}")
def _get_cached_value(self, motor: str, data_name: str) -> Value:
"""Retrieve a specific value from the state cache."""
@@ -938,12 +848,20 @@ class RobstrideMotorsBus(MotorsBusBase):
self._bus().send(msg)
updated_motors.append(motor)
messages = self._recv_all_messages_until_quiet()
processed_recv_ids = self._process_feedback_messages(messages)
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
for response in responses.values():
payload_motor_name = self._recv_id_to_motor.get(response.data[0])
if payload_motor_name is not None:
self._process_response(payload_motor_name, response)
else:
# Fallback: still attempt to decode based on payload byte0 mapping.
self._decode_motor_state(response.data)
for motor in updated_motors:
recv_id = self._get_motor_recv_id(motor)
if recv_id not in processed_recv_ids:
if recv_id not in responses:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
def read_calibration(self) -> dict[str, MotorCalibration]:

View File

@@ -114,8 +114,7 @@ CAN_CMD_SAVE_PARAM = 0xAA
CAN_PARAM_ID = 0x7FF
RUNNING_TIMEOUT = 0.003
HANDSHAKE_TIMEOUT_S = 0.05
RUNNING_TIMEOUT = 0.001
PARAM_TIMEOUT = 0.01
STATE_CACHE_TTL_S = 0.02

View File

@@ -20,7 +20,6 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
@@ -44,7 +43,6 @@ __all__ = [
"EO1Config",
"GaussianActorConfig",
"GrootConfig",
"MolmoAct2Config",
"MultiTaskDiTConfig",
"PI0Config",
"PI0FastConfig",

View File

@@ -28,12 +28,11 @@ import torch.nn.functional as F # noqa: N812
import torch.utils.checkpoint
from torch import Tensor
from lerobot.policies.eo1.configuration_eo1 import EO1Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
from ..pretrained import PreTrainedPolicy
from .configuration_eo1 import EO1Config
if TYPE_CHECKING or _transformers_available:
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration

View File

@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Any
import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.policies.eo1.configuration_eo1 import EO1Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
@@ -43,8 +44,6 @@ from lerobot.utils.constants import (
)
from lerobot.utils.import_utils import _transformers_available, require_package
from .configuration_eo1 import EO1Config
if TYPE_CHECKING or _transformers_available:
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
else:

View File

@@ -49,7 +49,6 @@ from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
from .pi05.configuration_pi05 import PI05Config
@@ -57,7 +56,6 @@ from .pretrained import PreTrainedPolicy
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from .vqbet.configuration_vqbet import VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig
from .xvla.configuration_xvla import XVLAConfig
@@ -90,8 +88,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x",
"molmoact2".
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -154,14 +151,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .eo1.modeling_eo1 import EO1Policy
return EO1Policy
elif name == "molmoact2":
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
return MolmoAct2Policy
elif name == "vla_jepa":
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
return VLAJEPAPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -179,7 +168,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
"smolvla", "wall_x", "molmoact2".
"smolvla", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -214,10 +203,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return WallXConfig(**kwargs)
elif policy_type == "eo1":
return EO1Config(**kwargs)
elif policy_type == "molmoact2":
return MolmoAct2Config(**kwargs)
elif policy_type == "vla_jepa":
return VLAJEPAConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -246,7 +231,6 @@ class ProcessorConfigKwargs(TypedDict, total=False):
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
dataset_meta: Any | None
def make_pre_post_processors(
@@ -422,7 +406,6 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, EO1Config):
from .eo1.processor_eo1 import make_eo1_pre_post_processors
@@ -431,23 +414,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, MolmoAct2Config):
from .molmoact2.processor_molmoact2 import make_molmoact2_pre_post_processors
processors = make_molmoact2_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, VLAJEPAConfig):
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
processors = make_vla_jepa_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
@@ -533,10 +499,6 @@ def make_policy(
action_names = ds_meta.features.get(ACTION, {}).get("names")
if action_names is not None:
cfg.action_feature_names = list(action_names)
if ds_meta is not None:
set_dataset_feature_metadata = getattr(cfg, "set_dataset_feature_metadata", None)
if callable(set_dataset_feature_metadata):
set_dataset_feature_metadata(ds_meta.features)
kwargs["config"] = cfg

View File

@@ -60,7 +60,6 @@ class Eagle25VLPreTrainedModel(PreTrainedModel):
"SiglipEncoderLayer",
]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = True
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_static_cache = True

View File

@@ -124,6 +124,7 @@ class Eagle25VLProcessor(ProcessorMixin):
"videos_kwargs",
"text_kwargs",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(

View File

@@ -14,7 +14,7 @@
# limitations under the License.
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
import numpy as np
import torch
@@ -26,14 +26,9 @@ from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from huggingface_hub.dataclasses import strict
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
def strict(cls):
return cls
AutoConfig = None
AutoModel = None
PretrainedConfig = object
@@ -178,20 +173,19 @@ N_COLOR_CHANNELS = 3
# config
@strict
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
backbone_cfg: dict[str, Any] | None = None
action_head_cfg: dict[str, Any] | None = None
action_horizon: int = 0
action_dim: int = 0
backbone_cfg: dict
action_head_cfg: dict
action_horizon: int
action_dim: int
compute_dtype: str = "float32"
def __post_init__(self, **kwargs):
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
super().__post_init__(**kwargs)
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
# real model

View File

@@ -206,11 +206,7 @@ def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS
"Vendor files are copied during model creation. Create the policy/model first, "
"or call ensure_eagle_cache_ready() before building processors."
)
proc = AutoProcessor.from_pretrained(
str(cache_dir),
trust_remote_code=True,
fix_mistral_regex=False,
)
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
proc.tokenizer.padding_side = "left"
return proc

View File

@@ -1 +0,0 @@
../../../../docs/source/policy_molmoact2_README.md

View File

@@ -1,21 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_molmoact2 import MolmoAct2Config
from .modeling_molmoact2 import MolmoAct2Policy
from .processor_molmoact2 import make_molmoact2_pre_post_processors
__all__ = ["MolmoAct2Config", "MolmoAct2Policy", "make_molmoact2_pre_post_processors"]

View File

@@ -1,519 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import math
import os
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
from lerobot.optim import (
AdamWConfig,
CosineDecayWithWarmupSchedulerConfig,
LRSchedulerConfig,
OptimizerConfig,
)
from lerobot.utils.constants import ACTION, OBS_STATE
from ..rtc.configuration_rtc import RTCConfig
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
MOLMOACT2_TASK_TOKEN_BUDGET = 32
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_checkpoint_location(
checkpoint_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
checkpoint_path = str(checkpoint_path or "").strip()
if not checkpoint_path:
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
local_path = Path(checkpoint_path).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=checkpoint_path,
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
def _load_hf_norm_metadata_for_tag(
checkpoint_path: str,
*,
revision: str | None,
force_download: bool,
norm_tag: str | None,
) -> dict[str, Any]:
norm_tag = str(norm_tag or "").strip()
if not norm_tag:
return {}
checkpoint_location = Path(
_resolve_checkpoint_location(
checkpoint_path,
revision=revision,
force_download=force_download,
)
)
norm_stats_filename = "norm_stats.json"
config_path = checkpoint_location / "config.json"
if config_path.exists():
with suppress(OSError, json.JSONDecodeError):
norm_stats_filename = str(
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
)
stats_path = checkpoint_location / norm_stats_filename
if not stats_path.exists():
raise FileNotFoundError(
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
)
payload = json.loads(stats_path.read_text())
metadata_by_tag = payload.get("metadata_by_tag")
if not isinstance(metadata_by_tag, dict):
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
metadata = metadata_by_tag.get(norm_tag)
if not isinstance(metadata, dict):
available = sorted(str(tag) for tag in metadata_by_tag)
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
return metadata
@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup")
@dataclass
class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig):
"""MolmoAct2-local cosine scheduler with optional decay-step auto-match.
LeRobot's generic cosine scheduler keeps an explicit integer decay length.
For MolmoAct2, leaving num_decay_steps unset means "decay across this run's
training steps"; build() is the first point where num_training_steps is known.
"""
num_decay_steps: int | None
def build(self, optimizer, num_training_steps: int):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.peak_lr,
decay_lr=self.decay_lr,
num_warmup_steps=self.num_warmup_steps,
num_decay_steps=num_training_steps if self.num_decay_steps is None else self.num_decay_steps,
).build(optimizer, num_training_steps=num_training_steps)
def _round_up(value: int, multiple: int) -> int:
return int(math.ceil(value / multiple) * multiple)
def infer_molmoact2_max_sequence_length(
*,
num_images: int,
state_dim: int,
action_dim: int,
action_horizon: int,
include_discrete_action: bool,
) -> int:
"""Infer the padded text/image sequence cap from MolmoAct2's fixed token layout."""
if num_images < 1:
num_images = MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim < 0:
state_dim = 0
if action_dim < 1:
action_dim = 1
if action_horizon < 1:
action_horizon = 1
image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
prompt_tokens = (
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
+ MOLMOACT2_TASK_TOKEN_BUDGET
+ state_dim
+ MOLMOACT2_SEQUENCE_LENGTH_MARGIN
)
action_tokens = 0
if include_discrete_action:
action_tokens_per_step = max(
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM),
)
action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step
return _round_up(
image_tokens + prompt_tokens + action_tokens,
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE,
)
@PreTrainedConfig.register_subclass("molmoact2")
@dataclass
class MolmoAct2Config(PreTrainedConfig):
"""MolmoAct2 policy backed by the converted HF checkpoint implementation."""
checkpoint_path: str = "allenai/MolmoAct2"
checkpoint_revision: str | None = None
checkpoint_force_download: bool = False
n_obs_steps: int = 1
chunk_size: int = 30
n_action_steps: int = 30
action_mode: str = "both"
inference_action_mode: str | None = None
discrete_action_tokenizer: str = "allenai/MolmoAct2-FAST-Tokenizer"
discrete_generation_max_steps: int | None = None
norm_tag: str | None = None
setup_type: str = ""
control_mode: str = ""
image_keys: list[str] = field(default_factory=list)
normalize_language: bool = True
add_setup_tokens: bool = True
add_control_tokens: bool = True
normalize_gripper: bool = False
num_state_tokens: int = 256
# Leave unset for the default MolmoAct2 sequence budget inferred from the fixed
# image/prompt/state/action token layout. Override only for unusual long prompts.
max_sequence_length: int | None = None
# Fixed by released MolmoAct2 checkpoints. We validate this at model load.
expected_max_action_dim: int = 32
# Flow-matching training knobs copied from the original MolmoAct2 training path.
num_flow_timesteps: int = 8
flow_matching_cutoff: float = 1.0
flow_matching_time_offset: float = 0.001
flow_matching_time_scale: float = 0.999
flow_matching_beta_alpha: float = 1.0
flow_matching_beta_beta: float = 1.5
num_inference_steps: int | None = None
mask_action_dim_padding: bool = True
enable_inference_cuda_graph: bool = True
# MolmoAct2-local eval option. When enabled, stochastic continuous action
# generation uses a rollout-local generator derived from eval_seed.
per_episode_seed: bool = False
eval_seed: int | None = None
rtc_config: RTCConfig | None = None
# Default is full finetuning with gradients from the action expert flowing into the VLM.
enable_lora_vlm: bool = False
lora_rank: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_bias: str = "none"
enable_lora_action_expert: bool = False
enable_knowledge_insulation: bool = False
freeze_embedding: bool = True
train_action_expert_only: bool = False
gradient_checkpointing: bool = False
model_dtype: str = "bfloat16"
softmax_auxiliary_loss: bool = True
softmax_auxiliary_loss_scale: float = 1e-4
discrete_loss_token_weighting: str = "root_subsegments_root_tokens"
optimizer_lr: float = 1e-5
optimizer_vit_lr: float = 5e-6
optimizer_connector_lr: float = 5e-6
optimizer_action_expert_lr: float = 5e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-6
optimizer_weight_decay: float = 0.0
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 200
scheduler_decay_steps: int | None = None
scheduler_decay_lr: float = 1e-6
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.QUANTILES,
"ACTION": NormalizationMode.QUANTILES,
}
)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
dataset_feature_names: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
super().__post_init__()
if self.action_mode not in {"continuous", "discrete", "both"}:
raise ValueError(
f"Unsupported action_mode={self.action_mode!r}. "
"Expected one of {'continuous', 'discrete', 'both'}."
)
if self.inference_action_mode not in {None, "continuous", "discrete"}:
raise ValueError(
f"Unsupported inference_action_mode={self.inference_action_mode!r}. "
"Expected one of {None, 'continuous', 'discrete'}."
)
if self.inference_action_mode == "continuous" and self.action_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' cannot run continuous inference.")
if self.inference_action_mode == "discrete" and self.action_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' cannot run discrete inference.")
if self.train_action_expert_only and self.action_mode != "continuous":
raise ValueError("MolmoAct2 train_action_expert_only requires action_mode='continuous'.")
if self.train_action_expert_only and self.enable_lora_vlm:
raise ValueError("MolmoAct2 train_action_expert_only is incompatible with enable_lora_vlm.")
if self.enable_lora_action_expert and not self.enable_lora_vlm:
raise ValueError("MolmoAct2 enable_lora_action_expert requires enable_lora_vlm.")
if self.chunk_size < 1:
raise ValueError(f"chunk_size must be >= 1, got {self.chunk_size}.")
if self.n_action_steps < 1:
raise ValueError(f"n_action_steps must be >= 1, got {self.n_action_steps}.")
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})."
)
if self.expected_max_action_dim != 32:
raise ValueError("MolmoAct2 released checkpoints use expected_max_action_dim=32.")
if self.model_dtype not in {"float32", "bfloat16", "float16"}:
raise ValueError(
f"Unsupported model_dtype={self.model_dtype!r}. Expected 'float32', 'bfloat16', or 'float16'."
)
if self.lora_rank < 1:
raise ValueError(f"lora_rank must be >= 1, got {self.lora_rank}.")
if self.lora_alpha < 1:
raise ValueError(f"lora_alpha must be >= 1, got {self.lora_alpha}.")
if not 0 <= self.lora_dropout <= 1:
raise ValueError(f"lora_dropout must be in [0, 1], got {self.lora_dropout}.")
if self.lora_bias not in {"none", "all", "lora_only"}:
raise ValueError(
f"Unsupported lora_bias={self.lora_bias!r}. Expected one of 'none', 'all', or 'lora_only'."
)
if self.discrete_loss_token_weighting not in {
"none",
"token",
"root_tokens",
"root_subsegments",
"root_subsegments_root_tokens",
}:
raise ValueError(
f"Unsupported discrete_loss_token_weighting={self.discrete_loss_token_weighting!r}."
)
if self.discrete_generation_max_steps is not None and self.discrete_generation_max_steps < 1:
raise ValueError(
f"discrete_generation_max_steps must be >= 1 or None, got {self.discrete_generation_max_steps}."
)
if self.max_sequence_length is not None and self.max_sequence_length < 1:
raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.")
def inferred_max_sequence_length(
self,
*,
num_images: int | None = None,
state_dim: int | None = None,
action_dim: int | None = None,
action_horizon: int | None = None,
include_discrete_action: bool | None = None,
) -> int:
if self.max_sequence_length is not None:
return int(self.max_sequence_length)
if num_images is None:
num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim is None:
state_feature = self.robot_state_feature
state_dim = int(state_feature.shape[0]) if state_feature is not None else 0
if action_dim is None:
action_feature = self.action_feature
action_dim = (
int(action_feature.shape[0]) if action_feature is not None else self.expected_max_action_dim
)
if action_horizon is None:
action_horizon = self.chunk_size
if include_discrete_action is None:
include_discrete_action = self.action_mode in {"discrete", "both"}
return infer_molmoact2_max_sequence_length(
num_images=int(num_images),
state_dim=int(state_dim),
action_dim=int(action_dim),
action_horizon=int(action_horizon),
include_discrete_action=bool(include_discrete_action),
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
def get_optimizer_preset(self) -> OptimizerConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return MolmoAct2CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
def set_dataset_feature_metadata(self, features: dict[str, Any]) -> None:
self.dataset_feature_names = {}
for key in (ACTION, OBS_STATE):
feature = features.get(key) if isinstance(features, dict) else None
if isinstance(feature, dict) and feature.get("names") is not None:
self.dataset_feature_names[key] = feature["names"]
def validate_features(self) -> None:
"""Validate and set up MolmoAct2 input and output features."""
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"MolmoAct2 policy requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
if OBS_STATE not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(0,),
)
self.input_features[OBS_STATE] = state_feature
if ACTION not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.expected_max_action_dim,),
)
self.output_features[ACTION] = action_feature
def apply_norm_tag_metadata(self) -> None:
if not str(self.norm_tag or "").strip():
return
metadata = _load_hf_norm_metadata_for_tag(
self.checkpoint_path,
revision=self.checkpoint_revision,
force_download=bool(self.checkpoint_force_download),
norm_tag=self.norm_tag,
)
if metadata.get("action_horizon") is not None:
self.chunk_size = int(metadata["action_horizon"])
if metadata.get("n_action_steps") is not None:
self.n_action_steps = int(metadata["n_action_steps"])
if not self.setup_type and metadata.get("setup_type") is not None:
self.setup_type = str(metadata["setup_type"])
if not self.control_mode and metadata.get("control_mode") is not None:
self.control_mode = str(metadata["control_mode"])
def saved_policy_action_mode(self) -> str | None:
pretrained_path = getattr(self, "pretrained_path", None)
if pretrained_path is None:
return None
config_path = Path(pretrained_path) / "config.json"
if not config_path.exists():
return None
try:
mode = json.loads(config_path.read_text()).get("action_mode")
except (OSError, json.JSONDecodeError):
return None
if mode in {"continuous", "discrete", "both"}:
return str(mode)
return None
def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str:
return saved_policy_action_mode or self.action_mode
def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None:
requested_mode = self.inference_action_mode
if requested_mode is None:
return
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
"continuous inference."
)
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
)
def validate_checkpoint_action_mode(
self,
checkpoint_action_mode: str,
*,
has_action_expert: bool,
) -> None:
if self.action_mode == "both" and checkpoint_action_mode != "both":
raise ValueError(
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
)
if self.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
raise ValueError(
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
f"got {checkpoint_action_mode!r}."
)
if self.action_mode in {"continuous", "both"} and not has_action_expert:
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
def resolve_inference_action_mode(
self,
requested_mode: str | None,
saved_policy_action_mode: str | None = None,
) -> str:
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode is None:
requested_mode = self.inference_action_mode
if requested_mode is None:
raise ValueError(
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
"to either 'continuous' or 'discrete'."
)
if requested_mode not in {"continuous", "discrete"}:
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
return requested_mode

View File

@@ -1,17 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa

View File

@@ -1,237 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
import logging
import os
from pathlib import Path
from typing import ClassVar
import numpy as np
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_tokenizer_location(
tokenizer_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
local_path = Path(str(tokenizer_path)).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=str(tokenizer_path),
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
class UniversalActionProcessor(ProcessorMixin):
attributes: ClassVar[list[str]] = ["tokenizer"]
tokenizer_class: str = "AutoTokenizer"
def __init__(
self,
tokenizer: PreTrainedTokenizerFast,
scale: float = 10,
vocab_size: int = 1024,
min_token: int = 0,
*,
action_dim: int | None = None,
time_horizon: int | None = None,
):
self.scale = scale
self.vocab_size = vocab_size
self.min_token = min_token
# Action horizon and dimension needed during decoding. These can be specified
# in three ways (in order of priority):
# 1. passed in as kwargs to decode()
# 2. in the constructor
# 3. cached from the last time decode() was called
self.time_horizon = time_horizon
self.action_dim = action_dim
self.called_time_horizon = time_horizon
self.called_action_dim = action_dim
super().__init__(tokenizer)
self.bpe_tokenizer = self.tokenizer
def __call__(self, action_chunk: np.array) -> np.array:
from scipy.fft import dct
assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
if action_chunk.ndim == 2:
action_chunk = action_chunk[None, ...]
# Cache the time horizon and action dimension for decoding
self.called_time_horizon = action_chunk.shape[-2]
self.called_action_dim = action_chunk.shape[-1]
dct_coeff = dct(action_chunk, axis=1, norm="ortho")
dct_coeff = np.around(dct_coeff * self.scale)
tokens = []
for elem in dct_coeff:
token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
return tokens
def decode(
self,
tokens: list[list[int]],
*,
time_horizon: int | None = None,
action_dim: int | None = None,
) -> np.array:
from scipy.fft import idct
self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
self.action_dim = action_dim or self.action_dim or self.called_action_dim
# Cache the time horizon and action dimension for the next call
self.called_time_horizon = self.time_horizon
self.called_action_dim = self.action_dim
assert self.time_horizon is not None and self.action_dim is not None, (
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
)
decoded_actions = []
for token in tokens:
try:
decoded_tokens = self.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
assert decoded_dct_coeff.shape == (
self.time_horizon,
self.action_dim,
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
)
except Exception as e:
print(f"Error decoding tokens: {e}")
print(f"Tokens: {token}")
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
@classmethod
def fit(
cls,
action_data: list[np.array],
scale: float = 10,
vocab_size: int = 1024,
*,
time_horizon: int | None = None,
action_dim: int | None = None,
) -> "UniversalActionProcessor":
from scipy.fft import dct
# Run DCT over all inputs
dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]
# Quantize and find min token
max_token = int(np.around(np.concatenate(dct_tokens) * scale).max())
min_token = int(np.around(np.concatenate(dct_tokens) * scale).min())
min_vocab_size = max_token - min_token
assert min_vocab_size <= vocab_size, (
f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
)
if min_vocab_size + 100 > vocab_size:
logging.warning(
f"Initial alphabet size {min_vocab_size} is almost as large as the vocab"
f"size {vocab_size}, consider increasing vocab size"
)
# Make token iterator for BPE training
def _token_iter():
for tokens in dct_tokens:
rounded_tokens = np.around(tokens * scale) - min_token
rounded_tokens = rounded_tokens.astype(int)
string = "".join(map(chr, rounded_tokens))
yield string
# Train BPE tokenizer
bpe = ByteLevelBPETokenizer()
# Set up the entire range of possible tokens as the initial alphabet
alphabet = [chr(i) for i in range(max_token - min_token + 1)]
trainer = BpeTrainer(
vocab_size=vocab_size,
min_frequency=2,
show_progress=True,
special_tokens=[],
initial_alphabet=alphabet,
max_token_length=10000,
)
# Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
# because it doesn't support custom alphabets)
bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)
return cls(
PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
scale=scale,
vocab_size=vocab_size,
min_token=min_token,
time_horizon=time_horizon,
action_dim=action_dim,
)
@classmethod
def from_pretrained_local(
cls,
pretrained_model_name_or_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> "UniversalActionProcessor":
location = Path(
_resolve_tokenizer_location(
pretrained_model_name_or_path,
revision=revision,
force_download=force_download,
)
)
processor_config = {}
processor_config_path = location / "processor_config.json"
if processor_config_path.exists():
import json
processor_config = json.loads(processor_config_path.read_text())
tokenizer = PreTrainedTokenizerFast.from_pretrained(str(location))
return cls(
tokenizer,
scale=processor_config.get("scale", 10),
vocab_size=processor_config.get("vocab_size", 1024),
min_token=processor_config.get("min_token", 0),
action_dim=processor_config.get("action_dim"),
time_horizon=processor_config.get("time_horizon"),
)

View File

@@ -1,553 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""
MolmoAct2 configuration
"""
from typing import Optional, Any
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class MolmoAct2VitConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`].
It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments,
defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```python
>>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer
>>> # Initializing a MolmoAct2VitConfig
>>> configuration = MolmoAct2VitConfig()
>>> # Initializing a MolmoAct2VisionTransformer (with random weights)
>>> model = MolmoAct2VisionTransformer(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "molmoact2"
base_config_key = "vit_config"
def __init__(
self,
hidden_size: int = 1152,
intermediate_size: int = 4304,
num_hidden_layers: int = 27,
num_attention_heads: int = 16,
num_key_value_heads: int = 16,
head_dim: int = 72,
hidden_act: str = "gelu_pytorch_tanh",
layer_norm_eps: float = 1e-6,
image_default_input_size: tuple[int, int] = (378, 378),
image_patch_size: int = 14,
image_num_pos: int = 577,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
initializer_range: float = 0.02,
float32_attention: bool = True,
attn_implementation: str = "eager",
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(attn_implementation=attn_implementation, **kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.layer_norm_eps = layer_norm_eps
self.image_default_input_size = image_default_input_size
self.image_patch_size = image_patch_size
self.image_num_pos = image_num_pos
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.initializer_range = initializer_range
self.float32_attention = float32_attention
@property
def image_num_patch(self):
h, w = self.image_default_input_size
return h // self.image_patch_size, w // self.image_patch_size
class MolmoAct2AdapterConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of MolmoAct2Adapter. With MolmoAct2VitConfig,
It is used to instantiate an MolmoAct2VisionBackbone according to the specified arguments,
defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```python
>>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone
>>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
>>> vit_config = MolmoAct2VitConfig()
>>> adapter_config = MolmoPoolingConfig()
>>> # Initializing a MolmoAct2VisionBackbone (with random weights)
>>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)
>>> # Accessing the model configuration
>>> vit_configuration = model.vit_config
>>> adapter_configuration = model.adapter_config
```"""
model_type = "molmoact2"
base_config_key = "adapter_config"
def __init__(
self,
vit_layers: tuple = (-3, -9),
pooling_attention_mask: bool = False,
hidden_size: int = 1152,
num_attention_heads: int = 16,
num_key_value_heads: int = 16,
head_dim: int = 72,
float32_attention: bool = True,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
hidden_act: str = "silu",
intermediate_size: int = 18944,
text_hidden_size: int = 3584,
image_feature_dropout: float = 0.0,
initializer_range: float = 0.02,
attn_implementation: str = "eager",
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(attn_implementation=attn_implementation, **kwargs)
self.vit_layers = vit_layers
self.pooling_attention_mask = pooling_attention_mask
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.float32_attention = float32_attention
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.text_hidden_size = text_hidden_size
self.image_feature_dropout = image_feature_dropout
self.initializer_range = initializer_range
class MolmoAct2TextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
`MolmoAct2TextModel` according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```python
>>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel
>>> # Initializing a MolmoAct2TextConfig
>>> configuration = MolmoAct2TextConfig()
>>> # Initializing a MolmoAct2TextModel (with random weights)
>>> model = MolmoAct2TextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "molmoact2_text"
base_config_key = "text_config"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"blocks.*.self_attn.att_proj": "colwise",
"blocks.*.self_attn.attn_out": "rowwise",
"blocks.*.mlp.ff_proj": "colwise",
"blocks.*.mlp.ff_out": "rowwise",
}
base_model_pp_plan = {
"wte": (["input_ids"], ["inputs_embeds"]),
"blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
"ln_f": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
hidden_size: int = 3584,
num_attention_heads: int = 28,
num_key_value_heads: int | None = 4,
head_dim: int = 128,
vocab_size: int = 152064,
additional_vocab_size: int = 128,
qkv_bias: bool = True,
num_hidden_layers: int = 48,
intermediate_size: int = 18944,
hidden_act: str = "silu",
embedding_dropout: float = 0.0,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
max_position_embeddings: int = 4096,
rope_theta: float = 1000000.0,
rope_scaling: dict[str, Any] = None,
rope_scaling_layers: list[int] | None = None,
use_qk_norm: bool = False,
qk_norm_type: str = "olmo",
layer_norm_eps: int = 1e-6,
norm_after: bool = False,
initializer_range: float = 0.02,
use_cache=True,
tie_word_embeddings=False,
attn_implementation: str = "eager",
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(
tie_word_embeddings=tie_word_embeddings, attn_implementation=attn_implementation, **kwargs
)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.vocab_size = vocab_size
self.additional_vocab_size = additional_vocab_size
self.qkv_bias = qkv_bias
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.embedding_dropout = embedding_dropout
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.rope_scaling_layers = rope_scaling_layers
self.use_qk_norm = use_qk_norm
self.qk_norm_type = qk_norm_type
self.layer_norm_eps = layer_norm_eps
self.norm_after = norm_after
self.initializer_range = initializer_range
self.use_cache = use_cache
# Validate the correctness of rotary position embeddings parameters
rope_config_validation(self)
class MolmoAct2ActionExpertConfig(PretrainedConfig):
r"""Configuration for the MolmoAct2 modern action expert."""
model_type = "molmoact2_action_expert"
base_config_key = "action_expert_config"
def __init__(
self,
max_action_horizon: int = 32,
max_action_dim: int = 32,
hidden_size: int = 1024,
num_layers: int = 32,
num_heads: int = 16,
mlp_ratio: float = 8.0 / 3.0,
ffn_multiple_of: int = 256,
timestep_embed_dim: int = 256,
dropout: float = 0.0,
attn_dropout: float = 0.0,
context_layer_norm: bool = True,
qk_norm: bool = True,
qk_norm_eps: float = 1e-6,
rope: bool = True,
causal_attn: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.max_action_horizon = max_action_horizon
self.max_action_dim = max_action_dim
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
self.ffn_multiple_of = ffn_multiple_of
self.timestep_embed_dim = timestep_embed_dim
self.dropout = dropout
self.attn_dropout = attn_dropout
self.context_layer_norm = context_layer_norm
self.qk_norm = qk_norm
self.qk_norm_eps = qk_norm_eps
self.rope = rope
self.causal_attn = causal_attn
def to_dict(self):
output = super().to_dict()
# These are derived from the parent MolmoAct2Config for HF exports. Keeping
# them out of the public nested config avoids duplicated sources of truth.
output.pop("max_action_horizon", None)
output.pop("max_action_dim", None)
return output
class MolmoAct2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
It is used to instantiate an MolmoAct2 model according to the specified arguments, defining the model architecture.
Example:
```python
>>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig
>>> # Initializing a MolmoAct2VitConfig
>>> vit_config = MolmoAct2VitConfig()
>>> # Initializing a MolmoAct2AdapterConfig
>>> adapter_config = MolmoAct2AdapterConfig()
>>> # Initializing a MolmoAct2TextConfig
>>> text_config = MolmoAct2TextConfig()
>>> # Initializing a MolmoAct2Config
>>> configuration = MolmoAct2Config(
>>> vit_config=vit_config,
>>> adapter_config=adapter_config,
>>> text_config=text_config,
>>> image_start_token_id=151936,
>>> image_end_token_id=151937,
>>> image_patch_id=151938,
>>> image_col_id=151939,
>>> low_res_image_start_token_id=151940,
>>> image_low_res_id=151942,
>>> frame_start_token_id=151943,
>>> frame_end_token_id=151944,
>>> )
>>> # Initializing a model
>>> model = MolmoAct2ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "molmoact2"
sub_configs = {
"text_config": MolmoAct2TextConfig,
"vit_config": MolmoAct2VitConfig,
"adapter_config": MolmoAct2AdapterConfig,
"action_expert_config": MolmoAct2ActionExpertConfig,
}
def __init__(
self,
vit_config: MolmoAct2VitConfig = None,
adapter_config: MolmoAct2AdapterConfig = None,
text_config: MolmoAct2TextConfig = None,
action_expert_config: MolmoAct2ActionExpertConfig = None,
image_start_token_id: int = None,
low_res_image_start_token_id: int = None,
image_end_token_id: int = None,
image_low_res_id: int = None,
image_patch_id: int = None,
image_col_id: int = None,
frame_start_token_id: int = None,
frame_end_token_id: int = None,
use_frame_special_tokens: bool = True,
initializer_range: float = 0.02,
add_action_expert: bool = True,
max_action_dim: int = 32,
max_action_horizon: int = 30,
n_obs_steps: int = 30,
action_mode: str = "both",
state_format: str = "discrete",
flow_matching_num_steps: int = 10,
flow_matching_cutoff: float = 1.0,
flow_matching_time_offset: float = 0.001,
flow_matching_time_scale: float = 0.999,
flow_matching_beta_alpha: float = 1.0,
flow_matching_beta_beta: float = 1.5,
mask_action_dim_padding: bool = True,
enable_depth_reasoning: bool = False,
depth_mode: int = 2,
num_depth_codes: int = 100,
action_expert_depth_gate: bool = False,
action_expert_depth_gate_per_layer: bool = False,
action_expert_depth_gate_init_bias: float = -4.0,
action_output_token_id: int = None,
action_start_token_id: int = None,
action_end_token_id: int = None,
action_token_start_id: int = None,
num_action_tokens: int = 0,
depth_output_token_id: int = None,
depth_start_token_id: int = None,
depth_end_token_id: int = None,
depth_token_start_id: int = None,
num_depth_tokens: int = 0,
state_start_token_id: int = None,
state_end_token_id: int = None,
state_token_start_id: int = None,
num_state_tokens: int = 0,
add_setup_tokens: bool = True,
add_control_tokens: bool = True,
norm_stats_filename: str = "norm_stats.json",
**kwargs,
):
super().__init__(**kwargs)
if vit_config is None:
self.vit_config = MolmoAct2VitConfig()
elif isinstance(vit_config, dict):
self.vit_config = MolmoAct2VitConfig(**vit_config)
else:
self.vit_config = vit_config
if adapter_config is None:
self.adapter_config = MolmoAct2AdapterConfig()
elif isinstance(adapter_config, dict):
self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
else:
self.adapter_config = adapter_config
if text_config is None:
self.text_config = MolmoAct2TextConfig()
elif isinstance(text_config, dict):
self.text_config = MolmoAct2TextConfig(**text_config)
else:
self.text_config = text_config
self.add_action_expert = bool(add_action_expert)
if not self.add_action_expert:
self.action_expert_config = None
elif action_expert_config is None:
self.action_expert_config = MolmoAct2ActionExpertConfig(
max_action_horizon=max_action_horizon,
max_action_dim=max_action_dim,
num_layers=self.text_config.num_hidden_layers,
)
elif isinstance(action_expert_config, dict):
self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
else:
self.action_expert_config = action_expert_config
if self.add_action_expert:
self.action_expert_config.max_action_dim = int(max_action_dim)
self.action_expert_config.max_action_horizon = int(max_action_horizon)
self._validate_release_action_config(
state_format=state_format,
)
self.image_start_token_id = image_start_token_id
self.low_res_image_start_token_id = low_res_image_start_token_id
self.image_end_token_id = image_end_token_id
self.image_low_res_id = image_low_res_id
self.image_high_res_id = image_patch_id
self.image_patch_id = image_patch_id
self.image_col_id = image_col_id
self.frame_start_token_id = frame_start_token_id
self.frame_end_token_id = frame_end_token_id
self.use_frame_special_tokens = use_frame_special_tokens
self.initializer_range = initializer_range
self.max_action_dim = max_action_dim
self.max_action_horizon = max_action_horizon
self.n_obs_steps = n_obs_steps
self.action_mode = action_mode
self.state_format = state_format
self.flow_matching_num_steps = flow_matching_num_steps
self.flow_matching_cutoff = flow_matching_cutoff
self.flow_matching_time_offset = flow_matching_time_offset
self.flow_matching_time_scale = flow_matching_time_scale
self.flow_matching_beta_alpha = flow_matching_beta_alpha
self.flow_matching_beta_beta = flow_matching_beta_beta
self.mask_action_dim_padding = mask_action_dim_padding
self.enable_depth_reasoning = enable_depth_reasoning
self.depth_mode = depth_mode
self.num_depth_codes = num_depth_codes
self.action_expert_depth_gate = action_expert_depth_gate
self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
self.action_output_token_id = action_output_token_id
self.action_start_token_id = action_start_token_id
self.action_end_token_id = action_end_token_id
self.action_token_start_id = action_token_start_id
self.num_action_tokens = num_action_tokens
self.depth_output_token_id = depth_output_token_id
self.depth_start_token_id = depth_start_token_id
self.depth_end_token_id = depth_end_token_id
self.depth_token_start_id = depth_token_start_id
self.num_depth_tokens = num_depth_tokens
self.state_start_token_id = state_start_token_id
self.state_end_token_id = state_end_token_id
self.state_token_start_id = state_token_start_id
self.num_state_tokens = num_state_tokens
self.add_setup_tokens = add_setup_tokens
self.add_control_tokens = add_control_tokens
self.norm_stats_filename = norm_stats_filename
@staticmethod
def _validate_release_action_config(
*,
state_format: str,
) -> None:
if state_format != "discrete":
raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")
@property
def image_num_patch(self):
assert self.vit_config is not None
return self.vit_config.image_num_patch
@property
def num_attention_heads(self):
return self.text_config.num_attention_heads
@property
def num_key_value_heads(self):
return self.text_config.num_key_value_heads
@property
def head_dim(self):
return self.text_config.head_dim
@property
def num_hidden_layers(self):
return self.text_config.num_hidden_layers
@property
def hidden_size(self):
return self.text_config.hidden_size
@property
def vocab_size(self):
return self.text_config.vocab_size
@property
def max_position_embeddings(self):
return self.text_config.max_position_embeddings
MolmoAct2VitConfig.register_for_auto_class()
MolmoAct2AdapterConfig.register_for_auto_class()
MolmoAct2TextConfig.register_for_auto_class()
MolmoAct2ActionExpertConfig.register_for_auto_class()
MolmoAct2Config.register_for_auto_class()

View File

@@ -1,564 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Image processor class for MolmoAct2"""
from typing import Optional, Union
import numpy as np
import einops
import torch
import torchvision.transforms
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
make_flat_list_of_images,
valid_images,
to_numpy_array,
)
from transformers.image_transforms import convert_to_rgb
from transformers.processing_utils import ImagesKwargs
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
from transformers.utils import logging
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import TensorType, logging
logger = logging.get_logger(__name__)
def normalize_image(
image: np.ndarray,
image_mean: list[float],
image_std: list[float],
) -> np.ndarray:
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
image /= np.array(image_std, dtype=np.float32)[None, None, :]
return image
def resize_image(
image: np.ndarray,
desired_output_size: list[int],
resample: PILImageResampling,
) -> np.ndarray:
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
dtype = image.dtype
if torch.is_floating_point(image):
in_min = 0.0
in_max = 1.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
)
in_min = 0.0
in_max = 255.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0, 255).to(dtype)
resized = resized.to(torch.float32)
resized = (resized - in_min) / (in_max - in_min)
resized = torch.permute(resized, [1, 2, 0]).numpy()
return resized
def select_tiling(h, w, patch_size, max_num_crops):
"""Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
original_size = np.stack([h, w]) # [1, 2]
original_res = h * w
tilings = []
for i in range(1, max_num_crops + 1):
for j in range(1, max_num_crops + 1):
if i * j <= max_num_crops:
tilings.append((i, j))
# sort so argmin and argmax favour smaller tilings in the event of a tie
tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
# How much we would need to scale the image to fit exactly in each tiling
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
# The original size can be zero in rare cases if the image is smaller than the margin
# In those cases letting the scale become infinite means the tiling is based on the
# other side, or falls back to the smallest tiling
with np.errstate(divide="ignore"):
required_scale_d = (candidate_resolutions.astype(np.float32) / original_size,)
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
if np.all(required_scale < 1):
# We are forced to downscale, so try to minimize the amount of downscaling
ix = np.argmax(required_scale)
else:
# Pick the resolution that required the least upscaling so that it most closely fits the image
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
ix = np.argmin(required_scale)
return candidate_tilings[ix]
def build_resized_image(
image: np.ndarray,
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
resized = resize_image(
image,
base_image_input_size,
resample,
)
resized = normalize_image(resized, image_mean, image_std)
if len(resized.shape) == 3:
resized = np.expand_dims(resized, 0)
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
return resized, resize_idx
def build_overlapping_crops(
image: np.ndarray,
max_crops: int,
overlap_margins: list[int],
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
"""Decompose an image into a set of overlapping crops
:return crop_arr: [n_crops, h, w, 3] The crops
:return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
the crops were extracted from, what patch in `crop_arr` it corresponds to
"""
original_image_h, original_image_w = image.shape[:2]
crop_size = base_image_input_size[0]
assert base_image_input_size[0] == base_image_input_size[1]
left_margin, right_margin = overlap_margins
total_margin_pixels = image_patch_size * (right_margin + left_margin) # pixels removed per dim
crop_patches = base_image_input_size[0] // image_patch_size # patches per crop dim
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
crop_window_size = crop_window_patches * image_patch_size
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
original_image_h, original_image_w = image.shape[:2]
crop_size = base_image_input_size[0]
# Decide how to tile the image, to account for the overlap margins we compute the tiling
# as if we had an image without the margins and were using a crop size without the margins
tiling = select_tiling(
original_image_h - total_margin_pixels,
original_image_w - total_margin_pixels,
crop_window_size,
max_crops,
)
src = resize_image(
image,
[
tiling[0] * crop_window_size + total_margin_pixels,
tiling[1] * crop_window_size + total_margin_pixels,
],
resample,
)
src = normalize_image(src, image_mean, image_std)
# Now we have to split the image into crops, and track what patches came from
# where in `patch_idx_arr`
n_crops = tiling[0] * tiling[1]
crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
on_crop = 0
for i in range(tiling[0]):
# Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size`
# which results in overlapping crop windows
y0 = i * crop_window_size
for j in range(tiling[1]):
x0 = j * crop_window_size
crop_arr[on_crop] = src[y0 : y0 + crop_size, x0 : x0 + crop_size]
patch_idx = np.arange(crop_patch_w * crop_patch_h).reshape(crop_patch_h, crop_patch_w)
patch_idx += on_crop * crop_patch_h * crop_patch_w
# Mask out idx that are in the overlap region
if i != 0:
patch_idx[:left_margin, :] = -1
if j != 0:
patch_idx[:, :left_margin] = -1
if i != tiling[0] - 1:
patch_idx[-right_margin:, :] = -1
if j != tiling[1] - 1:
patch_idx[:, -right_margin:] = -1
patch_idx_arr[on_crop] = patch_idx
on_crop += 1
# `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
# so it is ordered left-to-right order
patch_idx_arr = np.reshape(patch_idx_arr, [tiling[0], tiling[1], crop_patch_h, crop_patch_w])
patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
patch_idx_arr = np.reshape(patch_idx_arr, [-1])
# Now get the parts not in the overlap region, so it should map each patch in `src`
# to the correct patch it should come from in `crop_arr`
patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
src.shape[0] // image_patch_size,
src.shape[1] // image_patch_size,
)
return crop_arr, patch_idx_arr
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
if len(array.shape) == 3:
n_crops, h, w = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
array = np.transpose(array, [0, 1, 3, 2, 4])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
return array
else:
n_crops, h, w, c = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
return array
def arange_for_pooling(
idx_arr: np.ndarray,
pool_h: int,
pool_w: int,
) -> np.ndarray:
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
idx_arr = np.pad(
idx_arr,
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
mode="constant",
constant_values=-1,
)
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
def image_to_patches_and_grids(
image: np.ndarray,
max_crops: int,
overlap_margins: list[int],
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
image_pooling_w: int,
image_pooling_h: int,
crop_mode: str = "overlap-and-resize-c2",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
:return image_grids, the shape of each (low-res, high-res) image after pooling
:return crops, the image crops to processes with the ViT
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
patches in `crops` to pool for that token, masked with -1
"""
if isinstance(base_image_input_size, int):
base_image_input_size = (base_image_input_size, base_image_input_size)
base_image_input_d = image_patch_size
pooling_w = image_pooling_w
pooling_h = image_pooling_h
crop_patch_w = base_image_input_size[1] // base_image_input_d
crop_patch_h = base_image_input_size[0] // base_image_input_d
if crop_mode == "resize":
resized, resize_idx = build_resized_image(
image,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
resized_h, resized_w = resize_idx.shape[:2]
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
image_grid = [np.array([resized_h, resized_w, 0, 0])]
return (
np.stack(image_grid, 0),
batch_pixels_to_patches(resized, image_patch_size),
resize_idx,
)
if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")
crop_arr, patch_idx_arr = build_overlapping_crops(
image,
max_crops,
overlap_margins,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
h, w = pooling_idx.shape[:2]
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
# Finally do the same for the global image
resized, resize_idx = build_resized_image(
image,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
crop_arr = np.concatenate([resized, crop_arr], 0)
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
resized_h, resized_w = resize_idx.shape[:2]
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
# Global image goes first, so the order of patches in previous crops gets increased
pooling_idx = np.where(pooling_idx >= 0, pooling_idx + crop_patch_h * crop_patch_w, -1)
pooling_idx = np.concatenate([resize_idx, pooling_idx])
image_grid = [np.array([resized_h, resized_w, h, w])]
return (np.stack(image_grid, 0), batch_pixels_to_patches(crop_arr, image_patch_size), pooling_idx)
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
max_crops: int | None
overlap_margins: list[int] | None
crop_mode: str | None
patch_size: int | None
pooling_size: list[int] | None
class MolmoAct2ImageProcessor(BaseImageProcessor):
r"""
Constructs a MolmoAct2 image processor that preprocesses images for the model.
Args:
size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
Size of the image after resizing.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Resampling filter to use when resizing the image.
image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
max_crops (`int`, *optional*, defaults to `8`):
Maximum number of crops to use per image.
overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
Overlap margins to use.
patch_size (`int`, *optional*, defaults to 14):
The spatial patch size of the vision encoder.
pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
The pooling size of the vision adapter.
"""
model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]
def __init__(
self,
size: dict[str, int] | None = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool = True,
max_crops: int = 8,
overlap_margins: list[int] = [4, 4],
crop_mode: str = "overlap-and-resize-c2",
patch_size: int = 14,
pooling_size: list[int] = [2, 2],
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 378, "width": 378}
size = get_size_dict(size, default_to_square=True)
self.size = size
self.resample = resample
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_convert_rgb = do_convert_rgb
self.max_crops = max_crops
self.overlap_margins = overlap_margins
self.crop_mode = crop_mode
self.patch_size = patch_size
self.pooling_size = pooling_size
def preprocess(
self,
images: ImageInput,
size: dict[str, int] | None = None,
resample: PILImageResampling | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool | None = None,
max_crops: int | None = None,
overlap_margins: list[int] | None = None,
crop_mode: str | None = None,
patch_size: int | None = None,
pooling_size: list[int] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
"""
Args:
images (`ImageInput`):
Image to preprocess.
size (`dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
max_crops (`int`, *optional*, defaults to `self.max_crops`):
Maximum number of crops to use per image.
overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
Overlap margins to use.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
The spatial patch size of the vision encoder.
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
The pooling size of the vision adapter.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
Returns:
A `BatchFeature` containing the following keys:
- `pixel_values`: The preprocessed images.
- `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
- `image_grids`: The image grids.
- `image_num_crops`: The number of crops for each image.
"""
if size is not None:
if "height" not in size or "width" not in size:
raise ValueError("size must contain 'height' and 'width' keys.")
else:
size = {**self.size}
base_image_input_size = [size["height"], size["width"]]
resample = resample or self.resample
image_mean = image_mean or self.image_mean
image_std = image_std or self.image_std
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
max_crops = max_crops or self.max_crops
overlap_margins = overlap_margins or self.overlap_margins
crop_mode = crop_mode or self.crop_mode
patch_size = patch_size or self.patch_size
pooling_size = pooling_size or self.pooling_size
image_pooling_h, image_pooling_w = pooling_size
if images is not None:
images = self.fetch_images(images)
images = make_flat_list_of_images(images)
if images is not None and not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
data = {}
if images is not None:
batch_grids = []
batch_crops = []
batch_pooled_patches_idx = []
batch_num_crops = []
for image in images:
image_grid, crops, pooled_idx = image_to_patches_and_grids(
image,
max_crops,
overlap_margins,
base_image_input_size,
resample,
image_mean,
image_std,
patch_size,
image_pooling_w,
image_pooling_h,
crop_mode,
)
batch_grids.append(image_grid)
batch_crops.append(crops)
batch_pooled_patches_idx.append(pooled_idx)
batch_num_crops.append(crops.shape[0])
pixel_values = np.concatenate(batch_crops, 0)
image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
image_grids = np.concatenate(batch_grids, 0)
image_num_crops = np.array(batch_num_crops)
data.update(
pixel_values=pixel_values,
image_token_pooling=image_token_pooling,
image_grids=image_grids,
image_num_crops=image_num_crops,
)
return BatchFeature(data, tensor_type=return_tensors)
MolmoAct2ImageProcessor.register_for_auto_class()

View File

@@ -1,748 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Inference utilities for MolmoAct2"""
from dataclasses import dataclass
from typing import Any, Optional, Tuple
from collections.abc import Iterable, Sequence
import torch
from torch.nn import functional as F
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
@dataclass
class _ActionFlowInputs:
trajectory: torch.Tensor
context: Any
modulations: Sequence[Any]
action_dim_is_pad: torch.Tensor | None
@dataclass
class _ActionFlowCudaGraph:
key: tuple[Any, ...]
graph: torch.cuda.CUDAGraph
static_inputs: _ActionFlowInputs
output: torch.Tensor
@dataclass
class _DepthDecodeCudaGraphLayerStage:
residual: torch.Tensor
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
@dataclass
class _DepthDecodeCudaGraphPostStage:
graph: torch.cuda.CUDAGraph
attn_context: torch.Tensor
@dataclass
class _DepthDecodeCudaGraph:
cache_key: tuple[Any, ...]
pre_graph: torch.cuda.CUDAGraph
token_ids: torch.Tensor
cos: torch.Tensor
sin: torch.Tensor
positions: torch.Tensor
stages: Sequence[_DepthDecodeCudaGraphLayerStage]
post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
output: torch.Tensor
@dataclass
class _DepthDecodeCudaGraphSpec:
eligible: bool
cache_key_prefix: tuple[Any, ...]
num_hidden_layers: int
head_dim: int
num_attention_heads: int
def _cache_seq_len_int(past_key_values: Cache | None) -> int:
if past_key_values is None:
return 0
seq_len = past_key_values.get_seq_length()
if torch.is_tensor(seq_len):
return int(seq_len.item())
return int(seq_len)
def _cache_max_len_int(past_key_values: Cache | None) -> int:
if past_key_values is None:
return -1
max_len = past_key_values.get_max_cache_shape()
if torch.is_tensor(max_len):
return int(max_len.item())
return int(max_len)
def _iter_cache_key_values(
past_key_values: Cache,
) -> Iterable[tuple[torch.Tensor | None, torch.Tensor | None]]:
layers = getattr(past_key_values, "layers", None)
if layers is not None:
for layer in layers:
yield getattr(layer, "keys", None), getattr(layer, "values", None)
return
for layer in past_key_values:
yield layer[0], layer[1]
class _DepthDecodeStaticLayerCache:
is_compileable = False
is_sliding = False
def __init__(self, max_cache_len: int) -> None:
self.max_cache_len = int(max_cache_len)
self.cumulative_length = 0
self.keys: torch.Tensor | None = None
self.values: torch.Tensor | None = None
def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
bsz, n_heads = key_states.shape[:2]
self.keys = torch.empty(
(bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
dtype=key_states.dtype,
device=key_states.device,
)
self.values = torch.empty(
(bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
dtype=value_states.dtype,
device=value_states.device,
)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
*args,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.keys is None:
self._allocate(key_states, value_states)
start = self.cumulative_length
end = start + key_states.shape[-2]
if end > self.max_cache_len:
raise RuntimeError(f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}.")
self.keys[:, :, start:end, :].copy_(key_states)
self.values[:, :, start:end, :].copy_(value_states)
self.cumulative_length = end
return self.keys[:, :, :end, :], self.values[:, :, :end, :]
def get_seq_length(self) -> int:
return self.cumulative_length
def get_max_cache_shape(self) -> int:
return -1
def reset(self) -> None:
self.cumulative_length = 0
class _DepthDecodeStaticCache(Cache):
def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
text_config = config.get_text_config(decoder=True)
super().__init__(
layers=[
_DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
for _ in range(text_config.num_hidden_layers)
]
)
def get_seq_length(self, layer_idx: int = 0) -> int:
return self.layers[layer_idx].get_seq_length()
def get_max_cache_shape(self, layer_idx: int = 0) -> int:
return self.layers[layer_idx].get_max_cache_shape()
def reset(self) -> None:
for layer in self.layers:
layer.reset()
class ActionCudaGraphManager:
def __init__(self, model: Any) -> None:
self.model = model
self.enabled = True
self.action_flow_graph: _ActionFlowCudaGraph | None = None
def set_enabled(self, enabled: bool) -> None:
self.enabled = bool(enabled)
def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
action_model = self.model
if not self.enabled:
return False
if action_model.training or action_model._require_action_expert().training:
return False
if inputs.trajectory.device.type != "cuda":
return False
def all_on_cuda():
yield inputs.trajectory
for k, v in inputs.context.kv_contexts:
yield k
yield v
for t in (
inputs.context.cross_mask,
inputs.context.self_mask,
inputs.context.valid_action,
inputs.action_dim_is_pad,
):
if t is not None:
yield t
if inputs.context.rope_cache is not None:
yield from inputs.context.rope_cache
for step in inputs.modulations:
yield step.conditioning
for block_modulation in step.block_modulations:
yield from block_modulation
yield from step.final_modulation
return all(t.device.type == "cuda" for t in all_on_cuda())
def run_action_flow(
self,
inputs: _ActionFlowInputs,
steps: int,
run_loop,
) -> torch.Tensor:
key = _cuda_graph_key(inputs, steps)
cache = self.action_flow_graph
if cache is None or cache.key != key:
static_inputs = _clone_static_inputs(inputs)
graph, output = _capture_cuda_graph(
lambda: run_loop(static_inputs, steps),
inputs.trajectory.device,
after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
)
cache = _ActionFlowCudaGraph(
key=key,
graph=graph,
static_inputs=static_inputs,
output=output,
)
self.action_flow_graph = cache
else:
_copy_inputs_(cache.static_inputs, inputs)
cache.graph.replay()
return cache.output.clone()
class DepthDecodeCudaGraphManager:
def __init__(self, model: Any) -> None:
self.model = model
self.backbone = model.model
self.enabled = True
self.graph: _DepthDecodeCudaGraph | None = None
self.graph_spec: _DepthDecodeCudaGraphSpec | None = None
def set_enabled(self, enabled: bool) -> None:
self.enabled = bool(enabled)
def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
return _DepthDecodeStaticCache(
config=self.model.config.text_config,
max_cache_len=max_cache_len,
)
def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
static = self.graph_spec
if static is None:
cfg = self.backbone.transformer.config
rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
static = _DepthDecodeCudaGraphSpec(
eligible=(
not cfg.norm_after
and cfg.rope_scaling_layers is None
and getattr(rotary_emb, "rope_type", None) == "default"
and cfg._attn_implementation == "sdpa"
),
cache_key_prefix=(
cfg.hidden_size,
cfg.num_attention_heads,
cfg.num_key_value_heads,
cfg.head_dim,
cfg.num_hidden_layers,
cfg.use_qk_norm,
cfg.qk_norm_type,
cfg._attn_implementation,
),
num_hidden_layers=cfg.num_hidden_layers,
head_dim=cfg.head_dim,
num_attention_heads=cfg.num_attention_heads,
)
self.graph_spec = static
return static
def can_use(
self,
next_input_ids: torch.Tensor,
*,
past_key_values: Cache,
attention_bias: torch.Tensor,
) -> bool:
if not self.enabled or self.model.training or self.backbone.transformer.training:
return False
if next_input_ids.device.type != "cuda":
return False
if next_input_ids.ndim != 2 or next_input_ids.shape[0] != 1 or next_input_ids.shape[1] != 1:
return False
if not isinstance(past_key_values, _DepthDecodeStaticCache):
return False
if not torch.is_tensor(attention_bias) or attention_bias.device != next_input_ids.device:
return False
return self._depth_decode_spec().eligible
def _depth_decode_key(
self,
next_input_ids: torch.Tensor,
attention_bias: torch.Tensor,
) -> tuple[Any, ...]:
device = next_input_ids.device
return (
self._depth_decode_spec().cache_key_prefix,
device.type,
device.index,
self.model.lm_head.weight.dtype,
attention_bias.shape[-1],
)
def _select_depth_decode_rope(self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int) -> None:
emb = self.backbone.transformer.rotary_emb
cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
def _depth_decode_pre_layer(
self,
layer_idx: int,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
block = self.backbone.transformer.blocks[layer_idx]
attention = block.self_attn
residual = hidden_states
hidden_states = block.attn_norm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, attention.head_dim)
qkv = attention.att_proj(hidden_states)
query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
value_states = value_states.view(hidden_shape)
apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
if apply_qk_norm and not norm_after_view:
query_states = attention.q_norm(query_states)
key_states = attention.k_norm(key_states)
query_states = query_states.view(hidden_shape)
key_states = key_states.view(hidden_shape)
if norm_after_view:
query_states = attention.q_norm(query_states)
key_states = attention.k_norm(key_states)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)
return residual, query_states, key_states, value_states
def _depth_decode_pre0(
self,
token_ids: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
inputs_embeds = self.model._embed_base_tokens(token_ids)
return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
def _depth_decode_post_layer(
self,
layer_idx: int,
residual: torch.Tensor,
attn_context: torch.Tensor,
) -> torch.Tensor:
block = self.backbone.transformer.blocks[layer_idx]
attention = block.self_attn
input_shape = residual.shape[:-1]
attn_output = attn_context.reshape(*input_shape, -1).contiguous()
attn_output = attention.attn_out(attn_output)
hidden_states = residual + block.dropout(attn_output)
residual = hidden_states
hidden_states = block.ff_norm(hidden_states)
hidden_states = block.mlp(hidden_states)
hidden_states = residual + block.dropout(hidden_states)
return hidden_states
def _depth_decode_post_and_pre_next(
self,
layer_idx: int,
residual: torch.Tensor,
attn_context: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
def _depth_decode_last_post(
self,
layer_idx: int,
residual: torch.Tensor,
attn_context: torch.Tensor,
) -> torch.Tensor:
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
return self.backbone.transformer.ln_f(hidden_states)
def _build_depth_decode_graph(
self,
next_input_ids: torch.Tensor,
*,
past_length: int,
attention_bias: torch.Tensor,
) -> _DepthDecodeCudaGraph:
text_config = self.backbone.transformer.config
device = next_input_ids.device
dtype = self.model.lm_head.weight.dtype
static = self._depth_decode_spec()
num_layers = static.num_hidden_layers
head_dim = static.head_dim
max_cache_len = int(attention_bias.shape[-1])
max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
self.backbone.transformer.prepare_rope_cache(device=device, max_seq_len=max_rope_len)
token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
sin = torch.empty_like(cos)
positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
context_shape = (1, 1, static.num_attention_heads, head_dim)
token_ids.copy_(next_input_ids)
self._select_depth_decode_rope(cos, sin, past_length=past_length)
pre_graph, pre_output = _capture_cuda_graph(
lambda: self._depth_decode_pre0(token_ids, cos, sin),
device,
)
stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
post_graphs = []
for layer_idx in range(num_layers - 1):
stage = stages[-1]
attn_context = torch.empty(context_shape, device=device, dtype=dtype)
graph, output = _capture_cuda_graph(
lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
self._depth_decode_post_and_pre_next(
layer_idx,
stage.residual,
attn_context,
cos,
sin,
)
),
device,
)
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context))
stages.append(_DepthDecodeCudaGraphLayerStage(*output))
last_stage = stages[-1]
last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
last_graph, last_output = _capture_cuda_graph(
lambda: self._depth_decode_last_post(
num_layers - 1,
last_stage.residual,
last_attn_context,
),
device,
)
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=last_graph, attn_context=last_attn_context))
return _DepthDecodeCudaGraph(
cache_key=self._depth_decode_key(next_input_ids, attention_bias),
pre_graph=pre_graph,
token_ids=token_ids,
cos=cos,
sin=sin,
positions=positions,
stages=tuple(stages),
post_graphs=tuple(post_graphs),
output=last_output,
)
def _get_depth_decode_graph(
self,
next_input_ids: torch.Tensor,
*,
past_length: int,
attention_bias: torch.Tensor,
) -> _DepthDecodeCudaGraph:
key = self._depth_decode_key(next_input_ids, attention_bias)
decode_graph = self.graph
if decode_graph is None or decode_graph.cache_key != key:
decode_graph = self._build_depth_decode_graph(
next_input_ids,
past_length=past_length,
attention_bias=attention_bias,
)
self.graph = decode_graph
else:
decode_graph.token_ids.copy_(next_input_ids)
self._select_depth_decode_rope(decode_graph.cos, decode_graph.sin, past_length=past_length)
return decode_graph
def _run_depth_decode_attention_core(
self,
layer_idx: int,
stage: _DepthDecodeCudaGraphLayerStage,
*,
past_key_values: Cache,
attention_bias: torch.Tensor,
cache_position: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
attention = self.backbone.transformer.blocks[layer_idx].self_attn
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(
stage.key,
stage.value,
layer_idx,
cache_kwargs,
)
key_states = _repeat_kv(key_states, attention.num_key_value_groups)
value_states = _repeat_kv(value_states, attention.num_key_value_groups)
attn_output = F.scaled_dot_product_attention(
stage.query,
key_states,
value_states,
attn_mask=attention_bias,
dropout_p=0.0,
is_causal=False,
)
return attn_output.transpose(1, 2)
def run(
self,
next_input_ids: torch.Tensor,
*,
past_key_values: Cache,
attention_bias: torch.Tensor,
past_length: int,
) -> tuple[torch.Tensor, Cache]:
end = past_length + 1
decode_graph = self._get_depth_decode_graph(
next_input_ids,
past_length=past_length,
attention_bias=attention_bias,
)
cache_position = decode_graph.positions[past_length:end]
attention_bias_q = attention_bias[:, :, past_length:end, :end]
decode_graph.pre_graph.replay()
for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
attn_context = self._run_depth_decode_attention_core(
layer_idx,
decode_graph.stages[layer_idx],
past_key_values=past_key_values,
attention_bias=attention_bias_q,
cache_position=cache_position,
cos=decode_graph.cos,
sin=decode_graph.sin,
)
post_graph.attn_context.copy_(attn_context)
post_graph.graph.replay()
return decode_graph.output, past_key_values
def _cuda_graph_tensor_signature(
tensor: torch.Tensor | None,
) -> tuple[Any, ...] | None:
if tensor is None:
return None
return (
tuple(tensor.shape),
tuple(tensor.stride()),
str(tensor.dtype),
str(tensor.device),
)
def _cuda_graph_context_signature(context: Any) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return (
tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
sig(context.cross_mask),
sig(context.self_mask),
sig(context.valid_action),
None if context.rope_cache is None else tuple(sig(t) for t in context.rope_cache),
)
def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return tuple(
(
sig(step.conditioning),
tuple(tuple(sig(t) for t in block_modulation) for block_modulation in step.block_modulations),
tuple(sig(t) for t in step.final_modulation),
)
for step in modulations
)
def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return (
sig(inputs.trajectory),
_cuda_graph_context_signature(inputs.context),
_cuda_graph_modulation_signature(inputs.modulations),
sig(inputs.action_dim_is_pad),
int(steps),
)
def _clone_static_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None:
if tensor is None:
return None
static = torch.empty_strided(
tuple(tensor.shape),
tuple(tensor.stride()),
device=tensor.device,
dtype=tensor.dtype,
)
static.copy_(tensor)
return static
def _clone_static_context(context: Any) -> Any:
rope_cache = None
if context.rope_cache is not None:
rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
return context.__class__(
kv_contexts=tuple((_clone_static_tensor(k), _clone_static_tensor(v)) for k, v in context.kv_contexts),
cross_mask=_clone_static_tensor(context.cross_mask),
self_mask=_clone_static_tensor(context.self_mask),
valid_action=_clone_static_tensor(context.valid_action),
rope_cache=rope_cache,
)
def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
return tuple(
step.__class__(
conditioning=_clone_static_tensor(step.conditioning),
block_modulations=tuple(
tuple(_clone_static_tensor(t) for t in block_modulation)
for block_modulation in step.block_modulations
),
final_modulation=tuple(_clone_static_tensor(t) for t in step.final_modulation),
)
for step in modulations
)
def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
return _ActionFlowInputs(
trajectory=_clone_static_tensor(inputs.trajectory),
context=_clone_static_context(inputs.context),
modulations=_clone_static_modulations(inputs.modulations),
action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
)
def _copy_context_(dst: Any, src: Any) -> None:
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
dst_k.copy_(src_k)
dst_v.copy_(src_v)
if src.cross_mask is not None:
dst.cross_mask.copy_(src.cross_mask)
if src.self_mask is not None:
dst.self_mask.copy_(src.self_mask)
if src.valid_action is not None:
dst.valid_action.copy_(src.valid_action)
if src.rope_cache is not None:
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
dst_tensor.copy_(src_tensor)
def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
dst.trajectory.copy_(src.trajectory)
_copy_context_(dst.context, src.context)
if src.action_dim_is_pad is not None:
dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
unsqueeze_dim: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (_rotate_half(q) * sin)
k_embed = (k * cos) + (_rotate_half(k) * sin)
return q_embed, k_embed
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def _capture_cuda_graph(
fn,
device: torch.device,
*,
after_warmup=None,
) -> tuple[torch.cuda.CUDAGraph, Any]:
warmup_stream = torch.cuda.Stream(device=device)
warmup_stream.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(warmup_stream):
fn()
torch.cuda.current_stream(device).wait_stream(warmup_stream)
if after_warmup is not None:
after_warmup()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
output = fn()
return graph, output

File diff suppressed because it is too large Load Diff

View File

@@ -1,431 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""
Processor class for MolmoAct2.
"""
from typing import Optional, Union
import dataclasses
import numpy as np
from transformers.image_utils import ImageInput
from transformers.video_utils import VideoInput
from transformers.processing_utils import (
Unpack,
ProcessingKwargs,
ProcessorMixin,
)
from transformers.feature_extraction_utils import BatchFeature
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
from transformers.utils import logging
from transformers import AutoTokenizer
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
logger = logging.get_logger(__name__)
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
IM_START_TOKEN = f"<im_start>"
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
FRAME_START_TOKEN = f"<frame_start>"
IM_END_TOKEN = f"<im_end>"
FRAME_END_TOKEN = f"<frame_end>"
IM_COL_TOKEN = f"<im_col>"
IMAGE_PROMPT = "<|image|>"
VIDEO_PROMPT = "<|video|>"
IMAGE_TOKENS = [
IMAGE_PATCH_TOKEN,
IM_COL_TOKEN,
IM_START_TOKEN,
LOW_RES_IMAGE_START_TOKEN,
FRAME_START_TOKEN,
IM_END_TOKEN,
FRAME_END_TOKEN,
IMAGE_LOW_RES_TOKEN,
]
class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
"""MolmoAct2 processor kwargs"""
images_kwargs: MolmoAct2ImagesKwargs
videos_kwargs: MolmoAct2VideoProcessorKwargs
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": True,
},
"videos_kwargs": {"return_metadata": True},
}
class MolmoAct2Processor(ProcessorMixin):
attributes = ["image_processor", "video_processor", "tokenizer"]
optional_attributes = [
"chat_template",
"time_mode",
"image_use_col_tokens",
"use_single_crop_col_tokens",
"use_single_crop_start_token",
"video_use_col_tokens",
"use_frame_special_tokens",
]
image_processor_class = "AutoImageProcessor"
video_processor_class = "AutoVideoProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor: MolmoAct2ImageProcessor = None,
video_processor: MolmoAct2VideoProcessor = None,
tokenizer: AutoTokenizer = None,
chat_template: str | None = None,
image_use_col_tokens: bool | None = True,
use_single_crop_col_tokens: bool | None = None,
use_single_crop_start_token: bool | None = True,
video_use_col_tokens: bool | None = False,
use_frame_special_tokens: bool | None = True,
**kwargs,
) -> None:
super().__init__(
image_processor,
video_processor,
tokenizer,
chat_template=chat_template,
)
self.image_use_col_tokens = image_use_col_tokens
self.use_single_crop_col_tokens = use_single_crop_col_tokens
self.use_single_crop_start_token = use_single_crop_start_token
self.video_use_col_tokens = video_use_col_tokens
self.use_frame_special_tokens = use_frame_special_tokens
self.image_placeholder_token = IMAGE_PROMPT
self.video_placeholder_token = VIDEO_PROMPT
self.image_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in IMAGE_TOKENS]
def get_image_tokens(self, image_grid: np.ndarray):
resized_h, resized_w, height, width = image_grid
if int(height) == 0 or int(width) == 0:
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
use_single_crop_col_tokens = (
self.image_use_col_tokens
if self.use_single_crop_col_tokens is None
else self.use_single_crop_col_tokens
)
if use_single_crop_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
joint = [
[IM_START_TOKEN],
np.tile(per_row, [resized_h]),
[IM_END_TOKEN],
]
return np.concatenate(joint)
per_row = np.full(width, IMAGE_PATCH_TOKEN)
if self.image_use_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
joint = [
[IM_START_TOKEN],
np.tile(per_row, [height]),
[IM_END_TOKEN],
]
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
use_single_crop_col_tokens = (
self.image_use_col_tokens
if self.use_single_crop_col_tokens is None
else self.use_single_crop_col_tokens
)
image_start_token = LOW_RES_IMAGE_START_TOKEN if self.use_single_crop_start_token else IM_START_TOKEN
if use_single_crop_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
joint = [
[image_start_token],
np.tile(per_row, [resized_h]),
[IM_END_TOKEN],
] + joint
return np.concatenate(joint)
def get_video_string(
self,
video_grid: np.ndarray,
timestamps: np.ndarray,
):
if self.use_frame_special_tokens:
start_token_id = FRAME_START_TOKEN
end_token_id = FRAME_END_TOKEN
else:
start_token_id = IM_START_TOKEN
end_token_id = IM_END_TOKEN
num_frames, h, w = video_grid
video_string: str = ""
for frame_idx, frame_time in enumerate(timestamps):
# `per-frame-compact` time mode
prev_space = " " if frame_idx > 0 else ""
frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
video_string += frame_prefix
per_row = np.full(w, IMAGE_PATCH_TOKEN)
if self.video_use_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
extra_tokens = np.tile(per_row, [h])
video_tokens = [
[start_token_id],
extra_tokens,
[end_token_id],
]
video_string += "".join(np.concatenate(video_tokens, 0))
return video_string
def insert_bos(
self,
input_ids: np.ndarray,
attention_mask: np.ndarray,
bos_token_id: int,
pad_token_id: int,
):
"""
Args:
input_ids: [B, S] array with left padding
attention_mask: [B, S] array (0 for pad, 1 for valid)
bos_token_id: int
pad_token_id: int
Returns:
input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
attention_mask_out: same shape as input_ids_out
"""
need_to_expand = len(input_ids.shape) == 1
if need_to_expand:
input_ids = input_ids[None, :]
attention_mask = attention_mask[None, :]
B, S = input_ids.shape
# Handle zero-length sequence
if S == 0:
new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
if need_to_expand:
new_input_ids = new_input_ids[0]
new_attention_mask = new_attention_mask[0]
return new_input_ids, new_attention_mask
first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
if bos_already_present:
if need_to_expand:
input_ids = input_ids[0]
attention_mask = attention_mask[0]
return input_ids, attention_mask
else:
new_input_ids = np.full((B, S + 1), pad_token_id, dtype=input_ids.dtype)
new_attention_mask = np.zeros((B, S + 1), dtype=attention_mask.dtype)
src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
tgt_idx = src_idx + 1 # shit right
batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
# flatten valid_positions
flat_vals = input_ids[valid_mask]
flat_batch = batch_idx[valid_mask]
flat_tgt = tgt_idx[valid_mask]
new_input_ids[flat_batch, flat_tgt] = flat_vals
new_attention_mask[flat_batch, flat_tgt] = 1
insert_pos = first_valid_index
new_input_ids[np.arange(B), insert_pos] = bos_token_id
new_attention_mask[np.arange(B), insert_pos] = 1
if need_to_expand:
new_input_ids = new_input_ids[0]
new_attention_mask = new_attention_mask[0]
return new_input_ids, new_attention_mask
def __call__(
self,
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
images: ImageInput = None,
videos: VideoInput = None,
**kwargs: Unpack[MolmoAct2ProcessorKwargs],
) -> BatchFeature:
"""
Args:
text (`str`, `list[str]`, `list[list[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
videos (`dict[str, Any]` or `list[dict[str, Any]]`):
The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
- `"frames"`: `np.ndarray` of shape (T, H, W, 3)
- `"timestamps"`: `np.ndarray` of shape (T,)
- `"sampled_fps"`: `float` (optional)
- `"sampling_augmentation"`: `str` (optional)
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
`BatchFeature`: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
Returned when `images` is not `None`.
- **image_grids** -- Grids of images. Returned when `images` is not `None`.
- **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
- **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
Returned when `videos` is not `None`.
- **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
"""
output_kwargs = self._merge_kwargs(
MolmoAct2ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
image_grids = image_inputs["image_grids"]
else:
image_inputs = {}
image_grids = None
if videos is not None:
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grids = videos_inputs["video_grids"]
# If user has not requested video metadata, pop it
if "return_metadata" not in kwargs:
video_metadata = videos_inputs.pop("video_metadata")
else:
video_metadata = videos_inputs["video_metadata"]
else:
videos_inputs = {}
video_grids = None
if not isinstance(text, list):
text = [text]
text = text.copy() # below lines change text in-place
if image_grids is not None:
index = 0
for i in range(len(text)):
num_images = text[i].count(self.image_placeholder_token)
image_grids_i = image_grids[index : index + num_images]
for image_grid in image_grids_i:
image_tokens = self.get_image_tokens(image_grid)
image_string = "".join(image_tokens)
text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
index += num_images
if video_grids is not None:
index = 0
for i in range(len(text)):
num_videos = text[i].count(self.video_placeholder_token)
assert num_videos in {0, 1}, "At most one video is supported for now"
video_grids_i = video_grids[index : index + num_videos]
metadata_i = video_metadata[index : index + num_videos]
for video_grid, metadata in zip(video_grids_i, metadata_i):
video_string = self.get_video_string(
video_grid,
metadata.timestamps,
)
text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
index += num_videos
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
input_ids = text_inputs["input_ids"]
attention_mask = text_inputs["attention_mask"]
input_ids = np.array(input_ids)
attention_mask = np.array(attention_mask)
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
input_ids, attention_mask = self.insert_bos(
input_ids, attention_mask, bos, self.tokenizer.pad_token_id
)
if return_mm_token_type_ids:
image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
text_inputs["token_type_ids"] = token_type_ids.tolist()
text_inputs["input_ids"] = input_ids.tolist()
text_inputs["attention_mask"] = attention_mask.tolist()
return BatchFeature(
data={**text_inputs, **image_inputs, **videos_inputs},
tensor_type=return_tensors,
)
def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode method`.
Returns:
`list[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
MolmoAct2Processor.register_for_auto_class()

View File

@@ -1,997 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Video processor class for MolmoAct2"""
from functools import partial
import os
import warnings
from contextlib import redirect_stdout
from io import BytesIO
from urllib.parse import urlparse
from typing import Optional, Union
from collections.abc import Callable
import numpy as np
import requests
import einops
import torch
import torchvision.transforms
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
SizeDict,
validate_kwargs,
)
from transformers.video_utils import (
VideoInput,
is_valid_video,
make_batched_videos,
make_batched_metadata,
VideoMetadata,
)
from transformers.processing_utils import Unpack, VideosKwargs
from transformers.video_processing_utils import BaseVideoProcessor
from transformers.utils import logging
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import (
is_av_available,
is_decord_available,
is_torchcodec_available,
is_yt_dlp_available,
TensorType,
logging,
to_numpy,
)
logger = logging.get_logger(__name__)
MAX_VIDEO_FPS = 8
def normalize_image(
image: np.ndarray,
image_mean: list[float],
image_std: list[float],
) -> np.ndarray:
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
image /= np.array(image_std, dtype=np.float32)[None, None, :]
return image
def resize_image(
image: np.ndarray,
desired_output_size: list[int],
resample: PILImageResampling,
) -> np.ndarray:
if len(image.shape) == 3:
is_video = False
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
else:
is_video = True
image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2])
dtype = image.dtype
if torch.is_floating_point(image):
in_min = 0.0
in_max = 1.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
)
in_min = 0.0
in_max = 255.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0, 255).to(dtype)
resized = resized.to(torch.float32)
resized = (resized - in_min) / (in_max - in_min)
if is_video:
resized = torch.permute(resized, [0, 2, 3, 1]).numpy()
else:
resized = torch.permute(resized, [1, 2, 0]).numpy()
return resized
def build_resized_image(
image: np.ndarray,
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
resized = resize_image(
image,
base_image_input_size,
resample,
)
resized = normalize_image(resized, image_mean, image_std)
if len(resized.shape) == 3:
resized = np.expand_dims(resized, 0)
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
return resized, resize_idx
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
if len(array.shape) == 3:
n_crops, h, w = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
array = np.transpose(array, [0, 1, 3, 2, 4])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
return array
else:
n_crops, h, w, c = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
return array
def arange_for_pooling(
idx_arr: np.ndarray,
pool_h: int,
pool_w: int,
) -> np.ndarray:
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
idx_arr = np.pad(
idx_arr,
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
mode="constant",
constant_values=-1,
)
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
def image_to_patches_and_grids(
image: ImageInput,
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
image_pooling_w: int,
image_pooling_h: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
:return image_grids, the shape of each image after pooling
:return crops, the image crops to processes with the ViT
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
patches in `crops` to pool for that token, masked with -1
"""
if isinstance(base_image_input_size, int):
base_image_input_size = (base_image_input_size, base_image_input_size)
pooling_w = image_pooling_w
pooling_h = image_pooling_h
resized, resize_idx = build_resized_image(
image,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
h, w = pooling_idx.shape[:2]
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
image_grid = [h, w]
return (
image_grid,
batch_pixels_to_patches(resized, image_patch_size),
pooling_idx,
)
def get_candidate_target_fps(
video_fps: int | float,
sampling_fps: int | float,
max_fps: int | float = MAX_VIDEO_FPS,
) -> list[float]:
"""
Return the subset of `video_fps` factors that remain multiples of `sampling_fps`.
Examples:
>>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
[2, 6]
>>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
[1, 5]
>>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
[2]
>>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
Traceback (most recent call last):
...
ValueError: sampling_fps=2 must divide video_fps=5 to produce consistent frame steps.
"""
video_fps = int(video_fps)
sampling_fps = int(sampling_fps)
max_fps = int(max_fps)
if sampling_fps is None:
raise ValueError("sampling_fps must be provided")
if video_fps <= 0 or sampling_fps <= 0:
raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
if video_fps % sampling_fps != 0:
raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")
candidates = []
for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
if candidate > max_fps:
break
if video_fps % candidate == 0:
candidates.append(float(candidate))
return candidates
def read_video_decord(
video_path,
sample_timestamps_fn: Callable,
**kwargs,
) -> np.ndarray:
"""
Decode a video using the Decord backend.
Args:
video_path (`str`):
Path to the video file.
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
Returns:
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import from decord
import importlib
decord = importlib.import_module("decord")
vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu
video_fps = vr.get_avg_fps()
total_num_frames = len(vr)
time_stamps = vr.get_frame_timestamp(list(range(len(vr))))
duration = time_stamps[-1][1] - time_stamps[0][0]
metadata = VideoMetadata(
total_num_frames=int(total_num_frames),
fps=float(video_fps),
duration=float(duration),
video_backend="decord",
)
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
target_timestamps = np.array(target_timestamps)
offset = time_stamps[0, 0]
ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side="right")
ix = np.minimum(ix, len(time_stamps) - 1)
video = vr.get_batch(ix).asnumpy()
metadata.update(
{
"frames_indices": target_timestamps * video_fps,
"height": video.shape[1],
"width": video.shape[2],
}
)
return video, metadata
def read_video_torchcodec(
video_path,
sample_timestamps_fn: Callable,
**kwargs,
) -> np.ndarray:
"""
Decode a video using torchcodec decoder.
Args:
video_path (`str`):
Path to the video file.
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
Returns:
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import torchcodec
import importlib
torchcodec = importlib.import_module("torchcodec")
decoder = torchcodec.decoders.VideoDecoder(
video_path,
# Interestingly `exact` mode takes less than approximate when we load the whole video
seek_mode="exact",
# Allow FFmpeg decide on the number of threads for efficiency
num_ffmpeg_threads=0,
)
# If the first frame starts at > 0, we effectively clip the video starting at that time
# since (most) video players would also skip to that time
time_offset = decoder.metadata.begin_stream_seconds_from_content
# Note this duration does assume we started playing at `time_offset`
duration = decoder.metadata.duration_seconds
metadata = VideoMetadata(
total_num_frames=decoder.metadata.num_frames,
fps=decoder.metadata.average_fps,
duration=duration,
video_backend="torchcodec",
height=decoder.metadata.height,
width=decoder.metadata.width,
)
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
# Floating point/rounding issues might cause `target_timestamps` to be very slightly
# out-of-bounds, to handle this we sanity check then clip them
assert all(x >= 0 for x in target_timestamps)
assert all(x < duration + 1e-6 for x in target_timestamps)
# 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
# exact boundary value, we should still get the first/last frame anyway
max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6
# Note we avoid using numpy ops here to reduce floating precision issues
timestamps = [x + time_offset for x in target_timestamps]
timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]
video = (
decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1)
) # Convert to THWC format
target_timestamps = np.array(target_timestamps)
metadata.frames_indices = target_timestamps * metadata.fps
return video, metadata
def read_video_pyav(
video_path,
sample_timestamps_fn: Callable,
**kwargs,
) -> np.ndarray:
"""
Decode a video using the PyAV backend.
Args:
video_path (`str`):
Path to the video file.
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
Returns:
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import torchcodec
import importlib
av = importlib.import_module("av")
with av.open(video_path) as container:
video_stream = container.streams.video[0]
fps = video_stream.average_rate or video_stream.guessed_rate
it = container.decode(video=0)
frames = list(it)
stream = container.streams.video[0]
start = frames[0].pts * stream.time_base
container_end = stream.duration
if container_end is not None:
container_end *= stream.time_base
if container_end is None or container_end < frames[-1].pts:
# Some problem with stream duration, so use the frame PTS directly
# and guess the duration of the last frame
end = frames[-1].pts * stream.time_base + 1 / fps
else:
end = container_end
duration = float(end - start)
metadata = VideoMetadata(
total_num_frames=len(frames),
fps=float(fps),
duration=float(duration),
video_backend="pyav",
height=video_stream.height,
width=video_stream.width,
)
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
offset = float(start)
target_timestamps = np.array(target_timestamps)
end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side="right")
indices = np.minimum(indices, len(end_time_stamps) - 1)
video = np.stack(
[frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
axis=0,
)
metadata.frames_indices = target_timestamps * fps
return video, metadata
VIDEO_DECODERS = {
"decord": read_video_decord,
"torchcodec": read_video_torchcodec,
"pyav": read_video_pyav,
}
def load_video(
video: VideoInput,
backend: str = "decord",
sample_timestamps_fn: Callable | None = None,
**kwargs,
):
"""
Loads `video` to a numpy array.
Args:
video (`VideoInput`):
The video to convert to the numpy array format. Can be a link to video or local path.
backend (`str`, *optional*, defaults to `"decord"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", ""torchcodec"]. Defaults to "decord".
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
"""
# Early exit if provided an array or `PIL` frames
if not isinstance(video, str):
metadata = [None] * len(video)
return video, metadata
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
if not is_yt_dlp_available():
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
# Lazy import from yt_dlp
import importlib
yt_dlp = importlib.import_module("yt_dlp")
buffer = BytesIO()
with redirect_stdout(buffer), yt_dlp.YoutubeDL() as f:
f.download([video])
bytes_obj = buffer.getvalue()
file_obj = BytesIO(bytes_obj)
elif video.startswith("http://") or video.startswith("https://"):
file_obj = BytesIO(requests.get(video, timeout=10).content)
elif os.path.isfile(video):
file_obj = video
else:
raise TypeError(
"Incorrect format used for video. Should be an url linking to an video or a local path."
)
# can also load with decord, but not cv2/torchvision
# both will fail in case of url links
video_is_url = video.startswith("http://") or video.startswith("https://")
if video_is_url and backend == "opencv":
raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
if (
(not is_decord_available() and backend == "decord")
or (not is_torchcodec_available() and backend == "torchcodec")
or (not is_av_available() and backend == "pyav")
):
raise ImportError(
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
f"Make sure to install {backend} before loading the video."
)
video_decoder = VIDEO_DECODERS[backend]
video, metadata = video_decoder(file_obj, sample_timestamps_fn, **kwargs)
return video, metadata
def get_target_fps(
video_fps: float,
max_frames: int,
total_frames: int,
frame_sample_mode: str,
candidate_target_fps: tuple[float],
) -> float:
"""
Get the target fps that best spans the video and has the most frames sampled
"""
num_frames_sampled = 0
selected_target_fps = None
for target_fps in candidate_target_fps:
step_size = max(int(video_fps / target_fps), 1)
num_frames_sampled_at_fps = int(total_frames / step_size)
if num_frames_sampled == 0:
if "uniform" in frame_sample_mode:
if num_frames_sampled_at_fps > max_frames:
break
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
else:
# the candidate sampling fps increases so frame count can't decrease
assert num_frames_sampled <= num_frames_sampled_at_fps
if num_frames_sampled_at_fps > max_frames:
# choose the sampling fps that spans the video
continue
elif num_frames_sampled_at_fps > num_frames_sampled:
# both are less than max_frames, choose the one with higher density of frames sampled
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
return selected_target_fps
def get_frame_times_and_chosen_fps(selected_target_fps, total_frames, max_frames, video_fps):
if selected_target_fps is None:
frame_indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
else:
step_size = max(int(video_fps / selected_target_fps), 1)
frame_indices = np.arange(0, total_frames, step_size)
if len(frame_indices) > max_frames:
frame_indices = frame_indices[:max_frames]
return selected_target_fps, frame_indices
class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
patch_size: int | None
pooling_size: list[int] | None
frame_sample_mode: str | None
max_fps: int | None
sampling_fps: int | None
class MolmoAct2VideoProcessor(BaseVideoProcessor):
resample = PILImageResampling.BILINEAR
size = {"height": 378, "width": 378}
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
patch_size = 14
pooling_size = [3, 3]
do_sample_frames = True
frame_sample_mode = "uniform_last_frame"
max_fps = 2
sampling_fps = 2
valid_kwargs = MolmoAct2VideoProcessorKwargs
model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]
def __init__(self, **kwargs: Unpack[MolmoAct2VideoProcessorKwargs]):
super().__init__(**kwargs)
if self.size is not None and (
self.size.get("height", None) is None or self.size.get("width", None) is None
):
raise ValueError("size must contain 'height' and 'width' keys.")
def _further_process_kwargs(
self,
size: SizeDict | None = None,
**kwargs,
) -> dict:
"""
Update kwargs that need further processing before being validated
Can be overridden by subclasses to customize the processing of kwargs.
"""
if size is not None and ("height" not in size or "width" not in size):
raise ValueError("size must contain 'height' and 'width' keys.")
return super()._further_process_kwargs(size=size, **kwargs)
def sample_times(
self,
metadata: VideoMetadata,
frame_sample_mode: str,
num_frames: int,
max_fps: int | None = None,
sampling_fps: int | None = None,
**kwargs,
) -> np.ndarray:
"""
Time-based sampling if an array video is passed
Args:
metadata (`VideoMetadata`):
Metadata of the video containing information about total duration, fps and total number of frames.
frame_sample_mode (`str`, *optional*):
Mode to sample frames. Defaults to `self.frame_sample_mode`.
num_frames (`int`, *optional*):
Maximum number of frames to sample. Defaults to `self.num_frames`.
man_fps (`int`, *optional*):
Maximum frames per second to sample.
sampling_fps (`int`, *optional*):
Sampling frames per second. Defaults to `self.sampling_fps`.
Used when `frame_sample_mode` is `"fps"`.
"""
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
num_frames = num_frames or self.num_frames
sampling_fps = sampling_fps or self.sampling_fps
duration = metadata.duration or metadata.total_num_frames / metadata.fps
if frame_sample_mode == "fps":
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
# Try larger and larger FPSs until we hit one that can't span the video
target_fps = candidate_target_fps[0]
for candidate_fps in candidate_target_fps[1:]:
if num_frames / candidate_fps < duration:
break
target_fps = candidate_fps
times = np.arange(0, num_frames) / target_fps
times = times[times < duration]
return times
elif frame_sample_mode == "uniform_last_frame":
if max_fps is not None:
max_duration = (num_frames - 1) / max_fps # -1 to include the last frame
if max_duration < duration:
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
else:
times = np.arange(0.0, stop=duration, step=1 / max_fps)
times = np.concatenate([times, [duration]], axis=0)
assert len(times) <= num_frames
else:
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
return times
else:
raise NotImplementedError(frame_sample_mode)
def sample_frames(
self,
metadata: VideoMetadata,
frame_sample_mode: str | None = None,
num_frames: int | None = None,
max_fps: int | None = None,
sampling_fps: int | None = None,
**kwargs,
) -> np.ndarray:
"""
Frame-based sampling if an array video is passed
Args:
metadata (`VideoMetadata`):
Metadata of the video containing information about total duration, fps and total number of frames.
frame_sample_mode (`str`, *optional*):
Mode to sample frames. Defaults to `self.frame_sample_mode`.
num_frames (`int`, *optional*):
Maximum number of frames to sample. Defaults to `self.num_frames`.
max_fps (`int`, *optional*):
Maximum frames per second to sample.
sampling_fps (`int`, *optional*):
Sampling frames per second. Defaults to `self.sampling_fps`.
Used when `frame_sample_mode` is `"fps"`.
"""
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
num_frames = num_frames or self.num_frames
sampling_fps = sampling_fps or self.sampling_fps
total_num_frames = metadata.total_num_frames
if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
duration = total_num_frames / metadata.fps
if total_num_frames <= 2:
return np.arange(total_num_frames).astype(int)
if duration > (num_frames - 1) / max_fps: # -1 to include the last frame
# uniform fallback
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
return indices
else:
float_indices = np.arange(
0.0,
stop=total_num_frames - 1,
step=float(metadata.fps / max_fps),
)
if np.round(float_indices[-1]) != total_num_frames - 1:
float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
indices = np.round(float_indices).astype(int)
assert indices[-1] < total_num_frames
assert len(float_indices) <= num_frames
return indices
elif frame_sample_mode == "uniform_last_frame":
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
return indices
elif frame_sample_mode == "fps":
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
selected_target_fps = get_target_fps(
metadata.fps,
num_frames,
total_num_frames,
frame_sample_mode,
candidate_target_fps,
)
_, indices = get_frame_times_and_chosen_fps(
selected_target_fps,
total_num_frames,
num_frames,
metadata.fps,
)
return indices
else:
raise NotImplementedError(frame_sample_mode)
def fetch_videos(self, video_url_or_urls: str | list[str] | list[list[str]], sample_timestamps_fn=None):
"""
Convert a single or a list of urls into the corresponding `np.array` objects.
If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
returned.
"""
if (not is_decord_available()) and (not is_torchcodec_available()) and (not is_av_available()):
raise ImportError(
"MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
)
if is_decord_available():
backend = "decord"
elif is_torchcodec_available():
warnings.warn(
"`decord` is not installed and cannot be used to decode the video by default. "
"Falling back to `torchcodec`."
)
backend = "torchcodec"
else:
warnings.warn(
"`decord` is not installed and cannot be used to decode the video by default. "
"Falling back to `PyAV`."
)
backend = "pyav"
if isinstance(video_url_or_urls, list):
return list(
zip(
*[
self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn)
for x in video_url_or_urls
]
)
)
else:
return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)
def _decode_and_sample_videos(
self,
videos: VideoInput,
video_metadata: VideoMetadata | dict,
do_sample_frames: bool | None = None,
sample_indices_fn: Callable | None = None,
sample_timestamps_fn: Callable | None = None,
):
"""
Decode input videos and sample frames if needed.
"""
videos = make_batched_videos(videos)
video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
# Framed-based sampling if an array video is passed
# Otherwise, time-based sampling with decoding
if is_valid_video(videos[0]) and do_sample_frames:
assert video_metadata[0].fps is not None, "FPS must be provided for video input"
sampled_videos = []
sampled_metadata = []
for video, metadata in zip(videos, video_metadata):
indices = sample_indices_fn(metadata=metadata)
metadata.frames_indices = indices
sampled_videos.append(video[indices])
sampled_metadata.append(metadata)
videos = sampled_videos
video_metadata = sampled_metadata
elif not is_valid_video(videos[0]):
if sample_indices_fn is None:
logger.warning(
"do_sample_frames is False, but video array is not provided: "
"Will decode the video and sample frames using MolmoAct2's default sampling mode"
)
if isinstance(videos[0], list):
raise ValueError("A list of images is not supported for video input!")
else:
videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)
return videos, video_metadata
def _prepare_input_videos(
self,
videos: VideoInput,
**kwargs,
) -> list[np.ndarray]:
processed_videos = [to_numpy(video) for video in videos]
return processed_videos
def preprocess(
self,
videos: VideoInput,
**kwargs: Unpack[MolmoAct2VideoProcessorKwargs],
) -> BatchFeature:
validate_kwargs(
captured_kwargs=kwargs.keys(),
valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
do_sample_frames = kwargs.pop("do_sample_frames")
video_metadata = kwargs.pop("video_metadata")
sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
sample_timestamps_fn = partial(self.sample_times, **kwargs)
videos, video_metadata = self._decode_and_sample_videos(
videos,
video_metadata=video_metadata,
do_sample_frames=do_sample_frames,
sample_indices_fn=sample_indices_fn,
sample_timestamps_fn=sample_timestamps_fn,
)
videos = self._prepare_input_videos(videos=videos)
kwargs = self._further_process_kwargs(**kwargs)
return_metadata = kwargs.pop("return_metadata")
preprocessed_videos = self._preprocess(videos=videos, **kwargs)
if return_metadata:
preprocessed_videos["video_metadata"] = video_metadata
return preprocessed_videos
def _preprocess(
self,
videos: list[np.ndarray],
size: SizeDict | None = None,
resample: PILImageResampling | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool | None = None,
patch_size: int | None = None,
pooling_size: list[int] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
"""
Preprocess a video for the model.
Args:
videos (`VideoInput`):
Video to preprocess.
size (`SizeDict`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
The spatial patch size of the vision encoder.
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
The pooling size of the vision adapter.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
Returns:
A `BatchFeature` containing the following keys:
- `pixel_values_videos`: The preprocessed videos.
- `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`.
- `video_grids`: The video grids.
"""
if size.height is None or size.width is None:
raise ValueError("size must contain 'height' and 'width' keys.")
base_image_input_size = [size.height, size.width]
resample = resample or self.resample
image_mean = image_mean or self.image_mean
image_std = image_std or self.image_std
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
patch_size = patch_size or self.patch_size
pooling_size = pooling_size or self.pooling_size
image_pooling_h, image_pooling_w = pooling_size
batch_grids = []
batch_crops = []
batch_pooled_patches_idx = []
for video in videos:
all_crops = []
pooled_patches_idx = []
for frame in video:
image_grid, crops, pooled_idx = image_to_patches_and_grids(
frame,
base_image_input_size,
resample,
image_mean,
image_std,
patch_size,
image_pooling_w,
image_pooling_h,
)
offset = sum(np.prod(x.shape[:2]) for x in all_crops)
pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
pooled_patches_idx.append(pooled_idx_with_offset)
all_crops.append(crops)
video_grid = np.array([len(video), image_grid[0], image_grid[1]])
all_crops = np.concatenate(all_crops, 0)
pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)
batch_grids.append(video_grid)
batch_crops.append(all_crops)
batch_pooled_patches_idx.append(pooled_patches_idx)
video_grids = np.stack(batch_grids, 0)
pixel_values_videos = np.concatenate(batch_crops, 0)
video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
data = dict(
pixel_values_videos=pixel_values_videos,
video_token_pooling=video_token_pooling,
video_grids=video_grids,
)
return BatchFeature(data, tensor_type=return_tensors)
MolmoAct2VideoProcessor.register_for_auto_class()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -15,6 +15,7 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -29,7 +30,6 @@ from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
@@ -41,7 +41,6 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
DynamicCache = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
@@ -142,15 +141,6 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
return att_2d_masks & pad_2d_masks
def clone_past_key_values(past_key_values):
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
return DynamicCache(
tuple(
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
)
)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -237,13 +227,16 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = layers[i]
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
@@ -265,16 +258,15 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = rotary_emb(dummy_tensor, position_ids)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
paligemma_layer = layers[0]
scaling = paligemma_layer.self_attn.scaling
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma_layer.self_attn,
paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
@@ -282,13 +274,13 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma_layer.self_attn.head_dim
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = layers[i]
layer = models[i].layers[layer_idx]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
@@ -496,9 +488,8 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
paligemma_layers = self.paligemma.model.language_model.layers
gemma_expert_layers = self.gemma_expert.model.layers
rotary_emb = self.paligemma.model.language_model.rotary_emb
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
@@ -508,39 +499,36 @@ class PaliGemmaWithExpertModel(
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
layers=layers,
rotary_emb=rotary_emb,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
layers=layers,
rotary_emb=rotary_emb,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
# final norm
final_norms = (
self.paligemma.model.language_model.norm,
self.gemma_expert.model.norm,
)
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -919,7 +907,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = clone_past_key_values(past_key_values)
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,

View File

@@ -15,6 +15,7 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -29,7 +30,6 @@ from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
@@ -41,7 +41,6 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
DynamicCache = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
@@ -139,15 +138,6 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
return att_2d_masks & pad_2d_masks
def clone_past_key_values(past_key_values):
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
return DynamicCache(
tuple(
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
)
)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -234,13 +224,16 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = layers[i]
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
@@ -262,16 +255,15 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = rotary_emb(dummy_tensor, position_ids)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
paligemma_layer = layers[0]
scaling = paligemma_layer.self_attn.scaling
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma_layer.self_attn,
paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
@@ -279,13 +271,13 @@ def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_c
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma_layer.self_attn.head_dim
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = layers[i]
layer = models[i].layers[layer_idx]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
@@ -449,13 +441,13 @@ class PaliGemmaWithExpertModel(
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
return self.paligemma.model.language_model.embed_tokens(tokens)
def forward(
self,
@@ -493,9 +485,8 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
paligemma_layers = self.paligemma.model.language_model.layers
gemma_expert_layers = self.gemma_expert.model.layers
rotary_emb = self.paligemma.model.language_model.rotary_emb
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
@@ -505,39 +496,36 @@ class PaliGemmaWithExpertModel(
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
layers=layers,
rotary_emb=rotary_emb,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
layers=layers,
rotary_emb=rotary_emb,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
# final norm
final_norms = (
self.paligemma.model.language_model.norm,
self.gemma_expert.model.norm,
)
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -674,7 +662,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Process language tokens
def lang_embed_func(tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
return lang_emb
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb)
@@ -892,7 +881,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = clone_past_key_values(past_key_values)
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,

View File

@@ -1 +0,0 @@
../../../../docs/source/policy_vla_jepa_README.md

View File

@@ -1,23 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_vla_jepa import VLAJEPAConfig
from .modeling_vla_jepa import VLAJEPAPolicy
from .processor_vla_jepa import make_vla_jepa_pre_post_processors
__all__ = [
"VLAJEPAConfig",
"VLAJEPAPolicy",
"make_vla_jepa_pre_post_processors",
]

View File

@@ -1,337 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _diffusers_available, require_package
if TYPE_CHECKING or _diffusers_available:
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
else:
class ModelMixin: # type: ignore[no-redef]
pass
class ConfigMixin: # type: ignore[no-redef]
pass
register_to_config = lambda f: f # noqa: E731
Attention = FeedForward = TimestepEmbedding = Timesteps = None
from .configuration_vla_jepa import VLAJEPAConfig
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
batch_size, seq_len = timesteps.shape
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
freqs = timesteps.unsqueeze(-1) * exponent.exp()
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
class ActionEncoder(nn.Module):
def __init__(self, action_dim: int, hidden_size: int):
super().__init__()
self.layer1 = nn.Linear(action_dim, hidden_size)
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
self.layer3 = nn.Linear(hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = actions.shape
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
raise ValueError("timesteps must have shape [batch_size].")
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
action_emb = self.layer1(actions)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
require_package("diffusers", extra="vla_jepa")
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
return self.timestep_embedder(projected)
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
self.silu = nn.SiLU()
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float,
cross_attention_dim: int,
is_cross_attention: bool = True,
) -> None:
super().__init__()
self.is_cross_attention = is_cross_attention
self.norm1 = AdaLayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=True,
cross_attention_dim=cross_attention_dim,
out_bias=True,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None,
temb: torch.Tensor,
) -> torch.Tensor:
attn_input = self.norm1(hidden_states, temb)
attention_context = encoder_hidden_states if self.is_cross_attention else None
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
return hidden_states
class DiT(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
output_dim: int,
num_layers: int,
dropout: float,
cross_attention_dim: int,
) -> None:
super().__init__()
self.inner_dim = num_attention_heads * attention_head_dim
self.timestep_encoder = TimestepEncoder(self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
is_cross_attention=layer_idx % 2 == 0,
)
for layer_idx in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.Tensor,
) -> torch.Tensor:
temb = self.timestep_encoder(timestep)
x = hidden_states
for block in self.transformer_blocks:
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(x)
@dataclass
class ActionModelPreset:
hidden_size: int
attention_head_dim: int
num_attention_heads: int
DIT_PRESETS = {
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
}
class VLAJEPAActionHead(nn.Module):
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
super().__init__()
preset = DIT_PRESETS[config.action_model_type]
self.config = config
num_heads = config.action_num_heads or preset.num_attention_heads
head_dim = config.action_attention_head_dim or preset.attention_head_dim
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
self.input_embedding_dim = inner_dim
self.action_horizon = config.chunk_size
self.num_inference_timesteps = config.num_inference_timesteps
hidden_size = config.action_hidden_size
self.model = DiT(
num_attention_heads=num_heads,
attention_head_dim=head_dim,
output_dim=hidden_size,
num_layers=config.action_num_layers,
dropout=config.action_dropout,
cross_attention_dim=cross_attention_dim,
)
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
self.action_decoder = nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(hidden_size, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, config.action_dim)),
]
)
)
self.state_encoder = (
nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(config.state_dim, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, inner_dim)),
]
)
)
if config.state_dim > 0
else None
)
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
self.position_embedding = nn.Embedding(
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
inner_dim,
)
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
return (self.config.action_noise_s - sample) / self.config.action_noise_s
def _build_inputs(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None,
timesteps: torch.Tensor,
) -> torch.Tensor:
action_features = self.action_encoder(actions, timesteps)
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
action_features = action_features + self.position_embedding(pos_ids)[None]
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
seq = [future_tokens, action_features]
if state is not None and self.state_encoder is not None:
if state.ndim == 2:
state = state.unsqueeze(1)
seq.insert(0, self.state_encoder(state))
return torch.cat(seq, dim=1)
def forward(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None = None,
action_is_pad: torch.Tensor | None = None,
) -> torch.Tensor:
noise = torch.randn_like(actions)
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
velocity = actions - noise
t_discretized = (t * self.config.action_num_timestep_buckets).long()
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=t_discretized,
)
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
if action_is_pad is None:
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
num_valid = valid_mask.sum() * loss.shape[-1]
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
@torch.no_grad()
def predict_action(
self,
conditioning_tokens: torch.Tensor,
state: torch.Tensor | None = None,
) -> torch.Tensor:
batch_size = conditioning_tokens.shape[0]
actions = torch.randn(
batch_size,
self.action_horizon,
self.config.action_dim,
dtype=conditioning_tokens.dtype,
device=conditioning_tokens.device,
)
dt = 1.0 / max(self.num_inference_timesteps, 1)
for step in range(self.num_inference_timesteps):
t_cont = step / float(max(self.num_inference_timesteps, 1))
t_value = int(t_cont * self.config.action_num_timestep_buckets)
timesteps = torch.full(
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
)
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=timesteps,
)
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
actions = actions + dt * pred_velocity
return actions

View File

@@ -1,154 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("vla_jepa")
@dataclass
class VLAJEPAConfig(PreTrainedConfig):
n_obs_steps: int = 1
chunk_size: int = 7
n_action_steps: int = 7
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MIN_MAX,
}
)
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
freeze_qwen: bool = False
enable_world_model: bool = True
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
# different action or state dimensionality, the input/output projection layers must be
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
reinit_modules: list[str] | None = None
tokenizer_padding_side: str = "left"
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
special_action_token: str = "<|action_{}|>"
embodied_action_token: str = "<|embodied_action|>"
action_dim: int = 7
state_dim: int = 8
num_action_tokens_per_timestep: int = 8
num_embodied_action_tokens_per_instruction: int = 32
num_inference_timesteps: int = 4
action_hidden_size: int = 1024
action_model_type: str = "DiT-B"
action_num_layers: int = 16
action_num_heads: int | None = None
action_attention_head_dim: int | None = None
action_dropout: float = 0.2
action_num_timestep_buckets: int = 1000
action_noise_beta_alpha: float = 1.5
action_noise_beta_beta: float = 1.0
action_noise_s: float = 0.999
num_target_vision_tokens: int = 32
action_max_seq_len: int = 1024
# total video frames loaded per sample
num_video_frames: int = 8
predictor_depth: int = 12
predictor_num_heads: int = 8
predictor_mlp_ratio: float = 4.0
predictor_dropout: float = 0.0
world_model_loss_weight: float = 0.1
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
resize_images_to: tuple[int, int] | None = None
binarize_gripper_action: bool = True
pre_snap_gripper_action: bool = True
clip_normalized_actions: bool = True
gripper_dim: int = 6
gripper_threshold: float = 0.5
torch_dtype: str = "bfloat16"
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10.0
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self) -> None:
super().__post_init__()
if self.freeze_qwen and self.enable_world_model:
# freezing qwen backbone makes world model training irrelevant since no grad flows
self.enable_world_model = False
if self.n_action_steps > self.chunk_size:
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
if self.num_video_frames < 2 * self.jepa_tubelet_size:
raise ValueError(
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
)
def validate_features(self) -> None:
if not self.image_features:
raise ValueError("VLAJEPA requires at least one visual input feature.")
if self.action_feature is None:
raise ValueError("VLAJEPA requires an action output feature.")
self.action_dim = self.action_feature.shape[0]
if self.robot_state_feature is not None:
self.state_dim = self.robot_state_feature.shape[0]
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int]:
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
# matches original repo's observation_indices=list(range(video_horizon))
return list(range(self.num_video_frames))
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -1,629 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModel, AutoVideoProcessor
else:
AutoModel = None
AutoVideoProcessor = None
from .action_head import VLAJEPAActionHead
from .configuration_vla_jepa import VLAJEPAConfig
from .qwen_interface import Qwen3VLInterface
from .world_model import ActionConditionedVideoPredictor
# ============================================================================
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
# ============================================================================
class VLAJEPAModel(nn.Module):
"""
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
Components:
- Qwen3-VL: vision-language backbone for fused embeddings
- DiT-B: flow-matching action head for future action prediction
- V-JEPA: world model for video frame prediction
Input: List[dict] native format (same as original starVLA)
- "image": List[PIL.Image] (multi-view images)
- "video": np.ndarray [V, T, H, W, 3]
- "lang": str (task instruction)
- "action": np.ndarray [T, action_dim] (optional, training only)
- "state": np.ndarray [1, state_dim] (optional)
"""
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
require_package("transformers", extra="vla_jepa")
self.config = config
# Vision-language backbone
self.qwen = Qwen3VLInterface(config)
# Tokenizer expansion for special action tokens
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
self.qwen.expand_tokenizer()
)
# Action head (flow-matching DiT)
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
# JEPA world model components
if config.enable_world_model:
self.video_encoder = AutoModel.from_pretrained(
config.jepa_encoder_name,
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
)
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
num_views = config.jepa_tubelet_size
tubelet_size = self.video_encoder.config.tubelet_size
image_size = getattr(self.video_encoder.config, "image_size", None)
if image_size is None:
first_image_shape = next(iter(config.image_features.values())).shape
image_size = first_image_shape[-1]
self.video_predictor = ActionConditionedVideoPredictor(
num_frames=config.num_video_frames // tubelet_size,
img_size=(image_size, image_size),
patch_size=16,
tubelet_size=1,
embed_dim=self.video_encoder.config.hidden_size * num_views,
action_embed_dim=self.qwen.model.config.hidden_size,
predictor_embed_dim=self.video_encoder.config.hidden_size,
depth=config.predictor_depth,
num_heads=config.predictor_num_heads,
mlp_ratio=config.predictor_mlp_ratio,
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
)
else:
self.video_encoder = None
self.video_processor = None
self.video_predictor = None
if config.freeze_qwen:
self.qwen.requires_grad_(False)
# Build prompt placeholders.
# Use the encoder's actual tubelet_size when available (world model enabled),
# otherwise fall back to config.
_tubelet_size = (
self.video_encoder.config.tubelet_size
if config.enable_world_model
else self.config.jepa_tubelet_size
)
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
self.replace_prompt = "".join(
token * self.config.num_action_tokens_per_timestep
for token in self.action_tokens[:num_action_prompt_steps]
)
self.embodied_replace_prompt = (
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
)
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return the last decoder hidden state before the final RMSNorm.
The model was trained with the output of the last transformer block BEFORE
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
the correct pre-RMSNorm state, matching the training-time representation.
"""
captured: list[torch.Tensor] = []
def _hook(module, input, output):
h = output[0] if isinstance(output, tuple) else output
captured.append(h)
last_layer = self.qwen.model.model.language_model.layers[-1]
handle = last_layer.register_forward_hook(_hook)
try:
self.qwen.model(
**qwen_inputs,
output_hidden_states=False,
output_attentions=False,
return_dict=True,
)
finally:
handle.remove()
return captured[0] # [B, seq_len, H]
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
"""
Native forward pass following original starVLA VLA_JEPA.forward.
Args:
examples: List of per-sample dicts with keys:
"image" : List[PIL.Image] — multi-view images
"video" : np.ndarray [V, T, H, W, 3]
"lang" : str — task instruction
"action" : np.ndarray [T, action_dim] (optional)
"state" : np.ndarray [1, state_dim] (optional)
Returns:
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
"""
# Unpack native format (same pattern as original VLA_JEPA.py)
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
instructions = [ex["lang"] for ex in examples] # List[str]
has_action = "action" in examples[0] and examples[0]["action"] is not None
actions = [ex["action"] for ex in examples] if has_action else None
has_state = "state" in examples[0] and examples[0]["state"] is not None
state = [ex["state"] for ex in examples] if has_state else None
action_is_pad = (
[ex["action_is_pad"] for ex in examples]
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
else None
)
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
batch_videos = np.stack(batch_videos)
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
# Adjust number of views for the world model:
# - fewer views than expected: duplicate the first view to fill up
# - more views than expected: keep only the first num_views_world_model views
num_views_world_model = self.config.jepa_tubelet_size
if batch_videos.shape[1] < num_views_world_model:
num_missing_views = num_views_world_model - batch_videos.shape[1]
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
elif batch_videos.shape[1] > num_views_world_model:
batch_videos = batch_videos[:, :num_views_world_model]
# ---- Step 1: QwenVL encode (same as original) ----
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
# Locate embodied-action tokens (always needed for action head)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
# Locate action tokens (only needed for world model predictor)
if self.config.enable_world_model:
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
)
action_indices = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
if self.config.enable_world_model:
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
# ---- Step 2+3: JEPA Encoder + Predictor ----
device_wm = last_hidden.device
if not self.config.enable_world_model:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device) # [B*V, T, C, H, W]
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
tubelet_size = self.video_encoder.config.tubelet_size
device_wm = video_embeddings.device
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
t_enc_total = self.config.num_video_frames // tubelet_size
if t_enc_total < 2:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
# input_states: positions 0..T-2, gt_states: positions 1..T-1
t_enc_ctx = t_enc_total - 1
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
predicted_states = self.video_predictor(
input_states.float(),
action_tokens[:, :expected_actions].float(),
)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
if not has_action:
return {"wm_loss": wm_loss}
# ---- Step 4: Action Head ----
with torch.autocast(device_type=device_type, dtype=torch.float32):
actions_tensor = torch.tensor(
np.array(actions), device=last_hidden.device, dtype=torch.float32
) # [B, T_full, action_dim]
action_horizon = self.config.chunk_size
actions_target = actions_tensor[:, -action_horizon:, :]
state_tensor = None
if state is not None:
state_tensor = torch.tensor(
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
) # [B, 1, state_dim]
repeated_diffusion_steps = self.config.repeated_diffusion_steps
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
if state_tensor is not None:
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
action_is_pad_rep = None
if action_is_pad is not None:
pad_tensor = torch.stack(
[
p.to(actions_target.device)
if isinstance(p, Tensor)
else torch.tensor(p, device=actions_target.device)
for p in action_is_pad
]
) # [B, T_full]
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
action_loss = self.action_model(
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
)
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
@torch.no_grad()
def predict_action(
self,
batch_images: list[list[Image.Image]],
instructions: list[str],
state: np.ndarray | None = None,
) -> np.ndarray:
"""
Native action prediction following original VLA_JEPA.predict_action.
Args:
batch_images: List of samples; each is List[PIL.Image] (multi-view).
instructions: Task instructions, one per sample.
state: Optional [B, state_dim] numpy array.
Returns:
np.ndarray [B, action_horizon, action_dim] — predicted actions.
"""
if self.config.resize_images_to is not None:
height, width = self.config.resize_images_to
resampling = getattr(Image, "Resampling", Image).BOX
batch_images = [
[image.resize((width, height), resample=resampling) for image in sample_images]
for sample_images in batch_images
]
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
state_tensor = None
if state is not None:
state_tensor = torch.from_numpy(np.array(state)).to(
device=last_hidden.device, dtype=last_hidden.dtype
)
pred_actions = self.action_model.predict_action(
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
) # [B, action_horizon, action_dim]
return pred_actions.detach().cpu().numpy()
# ============================================================================
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
# ============================================================================
class VLAJEPAPolicy(PreTrainedPolicy):
"""
LeRobot adapter for VLA-JEPA.
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
back to LeRobot format.
"""
config_class = VLAJEPAConfig
name = "vla_jepa"
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
super().__init__(config)
config.validate_features()
if dataset_meta := kwargs.get("dataset_meta"):
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
# compatibility), so validate_features() may have read stale dims from a pretrained
# config. Override state_dim/action_dim from the actual dataset being used.
ds_features = dataset_meta.features
if OBS_STATE in ds_features:
config.state_dim = ds_features[OBS_STATE]["shape"][0]
if ACTION in ds_features:
config.action_dim = ds_features[ACTION]["shape"][0]
self.model = VLAJEPAModel(config)
self.reset()
def reset(self) -> None:
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
# ---- Format Conversion: LeRobot → Native ----
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
"""
Convert LeRobot batch format to native VLA-JEPA examples format.
LeRobot format:
batch = {
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
"action": Tensor [B, chunk_size, action_dim], (training only)
"task": str | List[str], (optional instruction)
}
Native format (List[dict]):
{
"image": List[PIL.Image], # multi-view images per sample
"video": np.ndarray [V, T, H, W, 3],
"lang": str, # task instruction
"action": np.ndarray [T, action_dim], # optional
"state": np.ndarray [1, state_dim], # optional
}
"""
# Determine batch size from the first image feature
image_keys = list(self.config.image_features.keys())
if not image_keys:
raise ValueError("VLAJEPA requires at least one image feature.")
first_key = image_keys[0]
first_tensor = batch[first_key]
batch_size = first_tensor.shape[0]
# ---- Collect images per sample ----
# images_per_sample[b][v] = PIL.Image for view v
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
for key in image_keys:
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
if tensor.ndim == 5:
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
# index 0 is the current observation (delta=0)
tensor = tensor[:, 0]
for b in range(batch_size):
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
# ---- Collect videos per sample ----
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
# Check whether any image feature has a time dimension
video_source = None
for k in image_keys:
if k in batch:
video_source = batch[k] # Use first available for shape inspection
break
if video_source is None:
raise ValueError("No image data found in batch for video construction.")
videos_per_sample = []
for b in range(batch_size):
sample_views = []
for k in image_keys:
t = batch[k][b] # [C, H, W] or [T, C, H, W]
if t.ndim == 3:
t = t.unsqueeze(0) # [1, C, H, W]
# Convert to [T, H, W, 3] numpy
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
# Clamp to [0, 255]
if t_np.max() <= 1.0:
t_np = t_np * 255.0
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
sample_views.append(t_np)
# Stack views: [V, T, H, W, 3]
videos_per_sample.append(np.stack(sample_views, axis=0))
# ---- Collect instructions ----
tasks = batch.get("task")
if tasks is None:
instructions = ["Execute the robot action."] * batch_size
elif isinstance(tasks, str):
instructions = [tasks] * batch_size
else:
instructions = list(tasks)
# ---- Collect actions (training only) ----
actions_list = None
action_is_pad_list = None
actions_tensor = batch.get(ACTION)
if actions_tensor is not None:
if actions_tensor.ndim == 2:
actions_tensor = actions_tensor.unsqueeze(1)
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
action_is_pad_tensor = batch.get("action_is_pad")
if action_is_pad_tensor is not None:
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
# ---- Collect state ----
state_list = None
state_tensor = batch.get(OBS_STATE)
if state_tensor is not None:
if state_tensor.ndim > 2:
state_tensor = state_tensor[:, -1, :]
if state_tensor.ndim == 2:
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
# ---- Assemble native examples ----
examples = []
for b in range(batch_size):
example = {
"image": images_per_sample[b],
"video": videos_per_sample[b],
"lang": instructions[b],
}
if actions_list is not None:
example["action"] = actions_list[b]
if action_is_pad_list is not None:
example["action_is_pad"] = action_is_pad_list[b]
if state_list is not None:
example["state"] = state_list[b]
examples.append(example)
return examples
# ---- LeRobot Policy Interface ----
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""LeRobot train forward: convert → native forward → aggregate losses."""
examples = self._prepare_model_inputs(batch)
native_output = self.model.forward(examples)
ref = next(iter(native_output.values()))
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
logs = {k: v.detach().item() for k, v in native_output.items()}
logs["loss"] = total_loss.detach().item()
return total_loss, logs
def get_optim_params(self) -> dict:
return self.model.parameters()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot inference: convert → native predict → return as Tensor."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
examples = self._prepare_model_inputs(batch)
batch_images = [ex["image"] for ex in examples]
instructions = [ex["lang"] for ex in examples]
state_np = None
if "state" in examples[0] and examples[0]["state"] is not None:
state_np = np.stack([ex["state"] for ex in examples])
actions_np = self.model.predict_action(batch_images, instructions, state_np)
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot select_action with action queue caching."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@classmethod
def from_pretrained(
cls: type[T],
pretrained_name_or_path: str | Path,
**kwargs,
):
return super().from_pretrained(pretrained_name_or_path, **kwargs)
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
reinit_prefixes = model.config.reinit_modules
if not reinit_prefixes:
return super()._load_as_safetensor(model, model_file, map_location, strict)
from safetensors.torch import load_file
state_dict = load_file(model_file, device=map_location)
current = model.state_dict()
reinitialized: list[str] = []
filtered: dict = {}
for key, value in state_dict.items():
if key in current and value.shape != current[key].shape:
if not any(key.startswith(p) for p in reinit_prefixes):
raise ValueError(
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
)
reinitialized.append(
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
)
else:
filtered[key] = value
if reinitialized:
logging.warning(
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
)
from lerobot.policies.utils import log_model_loading_keys
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
log_model_loading_keys(missing_keys, unexpected_keys)
return model

View File

@@ -1,155 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
import torch
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
class ClipActionsProcessorStep(ProcessorStep):
"""Clips action tensor to [-1, 1] before unnormalization."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None:
transition = dict(transition)
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
class PreSnapGripperProcessorStep(ProcessorStep):
"""Snaps a gripper dimension to {0, 1} BEFORE unnormalization.
Mirrors the original starVLA LIBERO eval:
normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1)
This ensures the unnormalizer receives an exact binary value, which is
required when the model was trained with gripper in identity (mask=False)
space where 0=open and 1=close.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
class BinarizeGripperProcessorStep(ProcessorStep):
"""Binarizes a gripper dimension after unnormalization.
Maps continuous value to {-1, 1}: > threshold → -1, <= threshold → 1 (matches starVLA convention).
Only applied when action has more dimensions than gripper_dim.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
def make_vla_jepa_pre_post_processors(
config: VLAJEPAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
features = {**config.input_features, **config.output_features}
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps: list[ProcessorStep] = []
if config.clip_normalized_actions:
output_steps.append(ClipActionsProcessorStep())
if config.pre_snap_gripper_action:
output_steps.append(
PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(
UnnormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
)
)
if config.binarize_gripper_action:
output_steps.append(
BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(DeviceProcessorStep(device="cpu"))
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)

View File

@@ -1,117 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
import numpy as np
import torch
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
else:
AutoProcessor = None
Qwen3VLForConditionalGeneration = None
from .configuration_vla_jepa import VLAJEPAConfig
class Qwen3VLInterface(torch.nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
config.qwen_model_name,
torch_dtype=self._get_torch_dtype(config.torch_dtype),
)
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
self.model.config.hidden_size = self.model.config.text_config.hidden_size
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
if dtype_name == "float32":
return torch.float32
if dtype_name == "float16":
return torch.float16
return torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
# is required for Qwen embedding/lm_head checkpoint shapes to match.
max_action_tokens = self.config.chunk_size * 4
tokenizer = self.processor.tokenizer
action_tokens = []
action_token_ids = []
for idx in range(max_action_tokens):
token = self.config.special_action_token.format(idx)
action_tokens.append(token)
if token not in tokenizer.get_vocab():
tokenizer.add_tokens([token], special_tokens=True)
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
embodied_action_token = self.config.embodied_action_token
if embodied_action_token not in tokenizer.get_vocab():
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
self.model.resize_token_embeddings(len(tokenizer))
return action_tokens, action_token_ids, embodied_action_token_id
def build_inputs(
self,
images: Sequence[Sequence[Image.Image]],
instructions: Sequence[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, torch.Tensor]:
messages = []
for sample_images, instruction in zip(images, instructions, strict=True):
prompt = self.config.prompt_template.format(
instruction=instruction,
actions=action_prompt,
e_actions=embodied_prompt,
)
content = [{"type": "image", "image": img} for img in sample_images]
content.append({"type": "text", "text": prompt})
messages.append([{"role": "user", "content": content}])
batch_inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
processor_kwargs={"padding": True, "return_tensors": "pt"},
)
return batch_inputs.to(self.model.device)
@staticmethod
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = image.float()
if image.max() <= 1.0:
image = image * 255.0
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
if image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
return Image.fromarray(image)

View File

@@ -1,418 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
def build_action_block_causal_attention_mask(
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
) -> torch.Tensor:
tokens_per_frame = add_tokens + grid_height * grid_width
num_tokens = num_frames * tokens_per_frame
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
local_window_time = num_frames
for current_frame in range(num_frames):
first_context_frame = max(0, current_frame - local_window_time + 1)
for context_frame in range(first_context_frame, current_frame + 1):
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
mask[row, col] = mask_block
return mask
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
_, _, _, dim = x.size()
if dim % 2 != 0:
raise ValueError("Embedding dimension must be even for rotary position encoding.")
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
omega /= dim / 2.0
omega = 1.0 / 10000**omega
freqs = torch.einsum("..., f -> ... f", pos, omega)
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
y = x.unflatten(-1, (-1, 2))
y1, y2 = y.unbind(dim=-1)
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
return x * emb_cos + y * emb_sin
class DropPath(nn.Module):
def __init__(self, drop_prob: float = 0.0) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
return x.div(keep_prob) * random_tensor
class MLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
act_layer: type[nn.Module] = nn.GELU,
drop: float = 0.0,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ACRoPEAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: float | None = None,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = qk_scale or self.head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop_prob = proj_drop
self.proj_drop = nn.Dropout(proj_drop)
self.use_sdpa = use_sdpa
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
self.grid_size = grid_size
self.is_causal = is_causal
@staticmethod
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
return ids // int(height * width)
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
frame_ids = self._get_frame_pos(ids, height, width)
ids = ids - int(height * width) * frame_ids
return ids // width
def separate_positions(
self, ids: torch.Tensor, height: int, width: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
frame_ids = self._get_frame_pos(ids, height, width)
height_ids = self._get_height_pos(ids, height, width)
width_ids = ids - int(height * width) * frame_ids - width * height_ids
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor | None = None,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
batch_size, num_tokens, channels = x.size()
if num_frames is None or grid_height is None or grid_width is None:
raise ValueError("num_frames, grid_height and grid_width are required.")
if mask is not None:
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
else:
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
h_mask *= self.grid_size / grid_height
w_mask *= self.grid_size / grid_width
if action_tokens > 0:
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
action_q, action_k, action_v = [], [], []
for idx in range(action_tokens):
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
qd = rotate_queries_or_keys(
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
kd = rotate_queries_or_keys(
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
qr = q[..., self.d_dim :]
kr = k[..., self.d_dim :]
action_q.append(
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_k.append(
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
x = x[:, :, action_tokens:, :].flatten(1, 2)
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
offset = 0
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
offset += self.d_dim
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
offset += self.h_dim
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
offset += self.w_dim
if offset < self.head_dim:
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
else:
q = torch.cat([qd, qh, qw], dim=-1)
k = torch.cat([kd, kh, kw], dim=-1)
if action_tokens > 0:
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
frame_tokens = frame_tokens.view(
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
)
action_token_values = action_token_values.view(
batch_size, self.num_heads, num_frames, action_tokens, -1
)
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
q = merge(q, action_q)
k = merge(k, action_k)
v = merge(v, action_v)
if attn_mask is not None or self.use_sdpa:
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
)
else:
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
x = self.proj(x)
return self.proj_drop(x)
class ACBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_scale: float | None = None,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
norm_layer: type[nn.Module] = nn.LayerNorm,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
use_rope: bool = True,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
if not use_rope:
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
self.attn = ACRoPEAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
use_sdpa=use_sdpa,
is_causal=is_causal,
grid_size=grid_size,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = MLP(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=nn.GELU,
drop=drop,
)
def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
y = self.norm1(x)
y = self.attn(
y,
mask=None,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=grid_height,
grid_width=grid_width,
action_tokens=action_tokens,
)
x = x + self.drop_path(y)
y = self.norm2(x)
return x + self.drop_path(self.mlp(y))
class ActionConditionedVideoPredictor(nn.Module):
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
def __init__(
self,
num_frames: int,
img_size: tuple[int, int],
patch_size: int,
tubelet_size: int,
embed_dim: int,
action_embed_dim: int,
predictor_embed_dim: int,
depth: int,
num_heads: int,
mlp_ratio: float,
num_action_tokens_per_step: int,
use_extrinsics: bool = False,
) -> None:
super().__init__()
self.is_frame_causal = True
self.use_extrinsics = use_extrinsics
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
self.img_height, self.img_width = img_size
self.patch_size = patch_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.grid_height = self.img_height // self.patch_size
self.grid_width = self.img_width // self.patch_size
self.predictor_blocks = nn.ModuleList(
[
ACBlock(
dim=predictor_embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
grid_size=self.grid_height,
use_rope=True,
)
for _ in range(depth)
]
)
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
self.num_action_tokens_per_step = num_action_tokens_per_step
@property
def norm(self) -> nn.LayerNorm:
return self.predictor_norm
@property
def proj(self) -> nn.Linear:
return self.predictor_proj
def forward(
self,
frame_tokens: torch.Tensor,
action_tokens: torch.Tensor,
extrinsics: torch.Tensor | None = None,
) -> torch.Tensor:
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
x = self.predictor_embed(frame_tokens)
batch_size, num_context_tokens, hidden_dim = x.size()
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
actions = self.action_encoder(action_tokens)
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
cond_tokens = actions.shape[2]
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
if self.use_extrinsics:
if extrinsics is None:
raise ValueError("extrinsics are required when use_extrinsics=True.")
cond_tokens += 1
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
else:
x = torch.cat([actions, x], dim=2).flatten(1, 2)
attn_mask = build_action_block_causal_attention_mask(
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
)
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
for block in self.predictor_blocks:
x = block(
x,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=self.grid_height,
grid_width=self.grid_width,
action_tokens=cond_tokens,
)
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
x = x[:, :, cond_tokens:, :].flatten(1, 2)
x = self.predictor_norm(x)
return self.predictor_proj(x)

View File

@@ -95,13 +95,6 @@ from .relative_action_processor import (
from .rename_processor import RenameObservationsProcessorStep, rename_stats
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
# RenderMessagesStep is intentionally NOT re-exported here: it pulls in
# `lerobot.datasets.language`, which requires the `[dataset]` extra
# (`datasets`, `pyarrow`). Importing it from the processor package would
# break every base-install consumer of `lerobot.processor`. Users that
# need it import directly:
# from lerobot.processor.render_messages_processor import RenderMessagesStep
__all__ = [
"ActionProcessorStep",
"AddTeleopActionAsComplimentaryDataStep",

View File

@@ -174,24 +174,6 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
complementary_data.pop("language_persistent", None)
complementary_data.pop("language_events", None)
if "messages" in complementary_data:
messages = complementary_data["messages"]
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
complementary_data["messages"] = [messages]
if "message_streams" in complementary_data:
streams = complementary_data["message_streams"]
if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
complementary_data["message_streams"] = [streams]
if "target_message_indices" in complementary_data:
indices = complementary_data["target_message_indices"]
if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
complementary_data["target_message_indices"] = [indices]
return complementary_data
def transform_features(

View File

@@ -153,30 +153,26 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
return x
_COMPLEMENTARY_KEYS = (
"task",
"index",
"task_index",
"episode_index",
"timestamp",
"language_persistent",
"language_events",
"messages",
"message_streams",
"target_message_indices",
)
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""Extract complementary data from a batch dictionary.
"""
Extract complementary data from a batch dictionary.
Includes padding flags (any key containing ``_is_pad``) plus the fixed
set of metadata / language keys defined in ``_COMPLEMENTARY_KEYS`` —
each only when present in ``batch``.
This includes padding flags, task description, and indices.
Args:
batch: The batch dictionary.
Returns:
A dictionary with the extracted complementary data.
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
extras = {k: batch[k] for k in _COMPLEMENTARY_KEYS if k in batch}
return {**pad_keys, **extras}
task_key = {"task": batch["task"]} if "task" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
def create_transition(

View File

@@ -1,84 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.configs.recipe import TrainingRecipe
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
from lerobot.datasets.language_render import render_sample
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.utils import unwrap_scalar
from .pipeline import ProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="render_messages_processor")
class RenderMessagesStep(ProcessorStep):
"""Processor step that turns raw language columns into rendered chat messages.
Reads ``language_persistent`` and ``language_events`` from the transition's
complementary data, renders them through ``recipe`` at the sample timestamp,
and replaces the raw columns with the resulting ``messages`` /
``message_streams`` / ``target_message_indices`` keys.
"""
recipe: TrainingRecipe
dataset_ctx: Any | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
"""Render messages for a single transition; return ``None`` to drop it."""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
events = complementary_data.get(LANGUAGE_EVENTS) or []
if not persistent and not events:
return transition
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
sample_idx = complementary_data.get("index", 0)
rendered = render_sample(
recipe=self.recipe,
persistent=persistent,
events=events,
t=unwrap_scalar(timestamp),
sample_idx=int(unwrap_scalar(sample_idx)),
task=complementary_data.get("task"),
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
return None
new_transition = transition.copy()
new_complementary_data = dict(complementary_data)
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Pass features through unchanged; rendering only touches complementary data."""
return features

View File

@@ -20,14 +20,14 @@ from .factory import (
make_reward_pre_post_processors as make_reward_pre_post_processors,
)
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig as RobometerConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfig
__all__ = [
# Configuration classes
"RewardClassifierConfig",
"RobometerConfig",
"SARMConfig",
"TOPRewardConfig",
# Base class
"PreTrainedRewardModel",
# Factory functions

View File

@@ -17,11 +17,10 @@ import logging
import torch
from torch import Tensor, nn
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.utils.constants import OBS_IMAGE, REWARD
from ..pretrained import PreTrainedRewardModel
from .configuration_classifier import RewardClassifierConfig
class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""

View File

@@ -25,8 +25,7 @@ from lerobot.processor import (
policy_action_to_transition,
transition_to_policy_action,
)
from .configuration_classifier import RewardClassifierConfig
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
def make_classifier_processor(

View File

@@ -22,11 +22,10 @@ import torch
from lerobot.configs.rewards import RewardModelConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from .classifier.configuration_classifier import RewardClassifierConfig
from .pretrained import PreTrainedRewardModel
from .sarm.configuration_sarm import SARMConfig
from .topreward.configuration_topreward import TOPRewardConfig
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
@@ -38,7 +37,7 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
Args:
name: The name of the reward model. Supported names are "reward_classifier",
"sarm", "topreward".
"sarm", "robometer".
Returns:
The reward model class corresponding to the given name.
@@ -54,10 +53,10 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "topreward":
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
elif name == "robometer":
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
return TOPRewardModel
return RobometerRewardModel
else:
try:
return _get_reward_model_cls_from_name(name=name)
@@ -74,7 +73,7 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
Args:
reward_type: The type of the reward model. Supported types include
"reward_classifier", "sarm", "topreward".
"reward_classifier", "sarm", "robometer".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -87,8 +86,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
return RewardClassifierConfig(**kwargs)
elif reward_type == "sarm":
return SARMConfig(**kwargs)
elif reward_type == "topreward":
return TOPRewardConfig(**kwargs)
elif reward_type == "robometer":
return RobometerConfig(**kwargs)
else:
try:
config_cls = RewardModelConfig.get_choice_class(reward_type)
@@ -168,11 +167,10 @@ def make_reward_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(reward_cfg, RobometerConfig):
from lerobot.rewards.robometer.processor_robometer import make_robometer_pre_post_processors
elif isinstance(reward_cfg, TOPRewardConfig):
from lerobot.rewards.topreward.processor_topreward import make_topreward_pre_post_processors
return make_topreward_pre_post_processors(
return make_robometer_pre_post_processors(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)

View File

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_topreward import TOPRewardConfig
from .modeling_topreward import TOPRewardModel
from .processor_topreward import make_topreward_pre_post_processors
from .configuration_robometer import RobometerConfig
from .modeling_robometer import RobometerRewardModel
from .processor_robometer import make_robometer_pre_post_processors
__all__ = ["TOPRewardConfig", "TOPRewardModel", "make_topreward_pre_post_processors"]
__all__ = ["RobometerConfig", "RobometerRewardModel", "make_robometer_pre_post_processors"]

View File

@@ -0,0 +1,229 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Upstream/legacy Robometer checkpoint loader.
This module is **only** used by the one-time conversion tooling
(:mod:`lerobot.scripts.lerobot_export_robometer` and
``scripts/verify_robometer_export.py``). It supports:
- Sharded upstream checkpoints (``model-0000X-of-Y.safetensors`` + index).
- PEFT/LoRA adapter checkpoints (``adapter_config.json`` + adapter weights).
- Local snapshot directories or Hugging Face Hub repo ids.
Once :class:`~lerobot.rewards.robometer.RobometerRewardModel` is loaded
through this module, calling ``save_pretrained`` writes the canonical
LeRobot-native layout (single ``model.safetensors`` + ``config.json``) that
the base loader understands.
The runtime path
(:meth:`~lerobot.rewards.pretrained.PreTrainedRewardModel.from_pretrained`)
does **not** import this file. It is safe to delete once you no longer need
the conversion tooling.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from torch import Tensor, nn
from lerobot.utils.import_utils import require_package
logger = logging.getLogger(__name__)
def _download_robometer_snapshot(
pretrained_path: str,
*,
hub_token: str | None = None,
) -> Path:
"""Resolve a Robometer snapshot directory.
- If ``pretrained_path`` is an existing local directory, return it directly.
- Otherwise treat ``pretrained_path`` as a Hugging Face repo id (optionally
with ``@revision``) and download it via ``snapshot_download``.
"""
local_candidate = Path(pretrained_path)
if local_candidate.is_dir():
return local_candidate
if "@" in pretrained_path:
repo_id, revision = pretrained_path.split("@", 1)
else:
repo_id, revision = pretrained_path, None
return Path(
snapshot_download(
repo_id=repo_id,
revision=revision,
token=hub_token,
allow_patterns=[
"*.json",
"*.safetensors",
"*.bin",
"*.txt",
"*.model",
"tokenizer*",
"special_tokens_map.json",
],
)
)
def _maybe_apply_peft(base_model: Any, snapshot_dir: Path) -> Any:
adapter_config = snapshot_dir / "adapter_config.json"
if not adapter_config.exists():
return base_model
require_package("peft", extra="peft-dep")
from peft import PeftModel
return PeftModel.from_pretrained(base_model, str(snapshot_dir))
def _remap_state_dict_keys(state_dict: dict[str, Tensor], model: nn.Module) -> dict[str, Tensor]:
"""Try a few common prefix swaps so PEFT-wrapped checkpoints load cleanly."""
model_keys = set(model.state_dict().keys())
remapped: dict[str, Tensor] = {}
for key, value in state_dict.items():
if key in model_keys:
remapped[key] = value
continue
candidates: list[str] = []
if key.startswith("model.model."):
candidates.append(key.replace("model.model.", "model.base_model.model.model.", 1))
candidates.append(key.replace("model.model.", "model.", 1))
if key.startswith("model."):
candidates.append(f"model.{key}")
candidates.append(key.replace("model.", "", 1))
else:
candidates.append(f"model.{key}")
if key.startswith("model.") and not key.startswith("model.base_model."):
parts = key.split(".", 1)
if len(parts) == 2:
candidates.append(f"model.base_model.{parts[1]}")
for candidate in candidates:
if candidate in model_keys:
remapped[candidate] = value
break
else:
remapped[key] = value
return remapped
def _resolve_checkpoint_safetensors_files(snapshot_dir: Path) -> list[Path]:
"""Pick the safetensors files that hold the full model weights.
When ``model.safetensors.index.json`` is present, only the files it lists are
loaded. Otherwise any ``model*.safetensors`` shards are preferred over
sidecar files. Falls back to every ``*.safetensors`` in the snapshot.
"""
index_path = snapshot_dir / "model.safetensors.index.json"
if index_path.exists():
with index_path.open() as f:
weight_map = json.load(f).get("weight_map", {})
indexed = sorted(
{snapshot_dir / name for name in weight_map.values() if (snapshot_dir / name).exists()}
)
if indexed:
return indexed
model_shards = sorted(snapshot_dir.glob("model*.safetensors"))
if model_shards:
return model_shards
return sorted(snapshot_dir.glob("*.safetensors"))
def apply_upstream_checkpoint(
model: nn.Module,
pretrained_path: str,
*,
hub_token: str | None = None,
) -> None:
"""Load an upstream (sharded / PEFT) Robometer checkpoint into ``model``.
Downloads the snapshot, optionally applies PEFT wrapping, merges sharded
``.safetensors`` files in memory, remaps PEFT-prefixed keys, and loads them
into ``model`` non-strictly. ``model`` must already be constructed with the
matching Robometer architecture (e.g. via
:class:`~lerobot.rewards.robometer.RobometerRewardModel` ``__init__``).
"""
snapshot_dir = _download_robometer_snapshot(pretrained_path, hub_token=hub_token)
# PEFT adapter checkpoints wrap the base model before weight loading so the
# remapper can place adapter tensors at the right prefix.
base_model = getattr(model, "model", None)
if base_model is not None:
wrapped = _maybe_apply_peft(base_model, snapshot_dir)
if wrapped is not base_model:
model.model = wrapped
files = _resolve_checkpoint_safetensors_files(snapshot_dir)
if not files:
logger.warning("No *.safetensors files in %s; using freshly initialised heads", snapshot_dir)
return
merged: dict[str, Tensor] = {}
for path in files:
merged.update(load_file(str(path)))
remapped = _remap_state_dict_keys(merged, model)
# Defensive vocab-match. With the corrected resize logic
# (``_resize_embeddings_for_robometer`` uses ``len(tokenizer) + 5``),
# a freshly built ``RobometerRewardModel`` should already share the same
# vocabulary as the upstream checkpoint (e.g. 151,674 for
# ``robometer/Robometer-4B``). This block stays in place as a safety net
# in case a future upstream variant uses a different vocab — we never
# want ``load_state_dict`` to trip on a silent shape mismatch.
base_model = getattr(model, "model", None)
if base_model is not None and hasattr(base_model, "get_input_embeddings"):
for key in (
"model.model.language_model.embed_tokens.weight",
"model.language_model.embed_tokens.weight",
"model.embed_tokens.weight",
):
tensor = remapped.get(key)
if tensor is None:
continue
ckpt_vocab = int(tensor.shape[0])
current_vocab = int(base_model.get_input_embeddings().num_embeddings)
if ckpt_vocab != current_vocab:
logger.info(
"Resizing model embed table %d -> %d to match upstream checkpoint vocab "
"(upstream was trained against a different Qwen revision).",
current_vocab,
ckpt_vocab,
)
base_model.resize_token_embeddings(ckpt_vocab)
break
missing, unexpected = model.load_state_dict(remapped, strict=False)
if missing:
logger.debug("Robometer checkpoint missing %d keys (sample: %s)", len(missing), missing[:5])
if unexpected:
logger.debug(
"Robometer checkpoint had %d unexpected keys (sample: %s)", len(unexpected), unexpected[:5]
)

Some files were not shown because too many files have changed in this diff Show More