mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
Compare commits
59 Commits
feat/robom
...
c2b29c8ae0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2b29c8ae0 | ||
|
|
58eac863aa | ||
|
|
952e5146dc | ||
|
|
37fda2a6fc | ||
|
|
df7d5132d1 | ||
|
|
8efa5cabe9 | ||
|
|
7e23859c55 | ||
|
|
a24f669deb | ||
|
|
b7727b8a6c | ||
|
|
7db4414e6b | ||
|
|
7da594fda8 | ||
|
|
01ce5d7af1 | ||
|
|
83ef59e020 | ||
|
|
997b713f14 | ||
|
|
1e3d25f10e | ||
|
|
8c0efb8295 | ||
|
|
9ef0fd5433 | ||
|
|
26d2ac48a8 | ||
|
|
0f29cd3167 | ||
|
|
64c9570547 | ||
|
|
8b03d25fef | ||
|
|
5e37d97631 | ||
|
|
a71e0d34ad | ||
|
|
d00b3e993a | ||
|
|
596c72bfc6 | ||
|
|
ea535ad98d | ||
|
|
7368a0085a | ||
|
|
7dba4f19a9 | ||
|
|
90d398ea59 | ||
|
|
16a4643000 | ||
|
|
6fa78ca8b4 | ||
|
|
ec4bf4e47f | ||
|
|
da56489174 | ||
|
|
addb354296 | ||
|
|
848ed3240e | ||
|
|
f7bb1795e7 | ||
|
|
0355902fba | ||
|
|
24017e960c | ||
|
|
e86f5af5bf | ||
|
|
5c98e80430 | ||
|
|
f65f3f7a4a | ||
|
|
8194897994 | ||
|
|
9f437d86b6 | ||
|
|
b74a551d38 | ||
|
|
c0a2e9814d | ||
|
|
bac4f61eae | ||
|
|
f4b834844e | ||
|
|
dfdc48a7f1 | ||
|
|
6a8878a639 | ||
|
|
d38eb89f71 | ||
|
|
7ab4936b1b | ||
|
|
ca8c60a0ed | ||
|
|
3c15fd8537 | ||
|
|
5ebbdf3d05 | ||
|
|
6e035fb169 | ||
|
|
01dcb4c292 | ||
|
|
bd9619dfc3 | ||
|
|
0a4a7c40ad | ||
|
|
ca9028ad64 |
@@ -3,6 +3,8 @@
|
||||
title: LeRobot
|
||||
- local: installation
|
||||
title: Installation
|
||||
- local: cheat-sheet
|
||||
title: Cheat sheet
|
||||
title: Get started
|
||||
- sections:
|
||||
- local: il_robots
|
||||
@@ -37,8 +39,12 @@
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
- local: language_and_recipes
|
||||
title: Language Columns and Recipes
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: video_encoding_parameters
|
||||
title: Video encoding parameters
|
||||
- local: streaming_video_encoding
|
||||
title: Streaming Video Encoding
|
||||
title: "Datasets"
|
||||
@@ -53,6 +59,8 @@
|
||||
title: π₀-FAST (Pi0Fast)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: molmoact2
|
||||
title: MolmoAct2
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: groot
|
||||
@@ -67,6 +75,8 @@
|
||||
- sections:
|
||||
- local: sarm
|
||||
title: SARM
|
||||
- local: topreward
|
||||
title: TOPReward
|
||||
title: "Reward Models"
|
||||
- sections:
|
||||
- local: inference
|
||||
@@ -139,6 +149,8 @@
|
||||
title: OMX
|
||||
- local: openarm
|
||||
title: OpenArm
|
||||
- local: rebot_b601
|
||||
title: reBot B601-DM
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
|
||||
@@ -79,17 +79,13 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
|
||||
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/act_policy \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_robot \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Your task description" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
--task="Your task description" \ # can be skipped for ACT
|
||||
--duration=60
|
||||
```
|
||||
|
||||
139
docs/source/cheat-sheet.mdx
Normal file
139
docs/source/cheat-sheet.mdx
Normal file
@@ -0,0 +1,139 @@
|
||||
# Cheat sheet
|
||||
|
||||
All of the LeRobot commands in one place. If you forgot how to use a specific command or want to learn about a new one you can do it here.
|
||||
|
||||
> [!WARNING]
|
||||
> For all of the commands listed below remember to change the ports/names/ids to your own values!
|
||||
|
||||
> [!TIP]
|
||||
> Another great way to look at all the commands and get them configured for your specific setup is to use this [Jupyter Notebook](https://github.com/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb).
|
||||
|
||||
### Setup and installation
|
||||
|
||||
For installation please look at [LeRobot Installation](https://huggingface.co/docs/lerobot/main/en/installation).
|
||||
|
||||
### Useful tools
|
||||
|
||||
###### Find port
|
||||
|
||||
Use this to identify which serial ports your robots are connected to. Follow the instructions in your terminal: you will be asked to unplug the USB cable and press Enter. The script will then detect and print the correct serial port for that robot.
|
||||
|
||||
```bash
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
###### Find cameras
|
||||
|
||||
Quickly find camera indices and verify their output. This command prints camera information to the terminal and saves test frames from each detected camera to `lerobot/outputs/captured_images`
|
||||
|
||||
```bash
|
||||
lerobot-find-cameras
|
||||
```
|
||||
|
||||
### Calibration
|
||||
|
||||
In most cases you will need to perform calibration just once for each robot and teleoperation device. Before performing the calibration make sure that all the joints are roughly in the middle position.
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_follower_arm
|
||||
```
|
||||
|
||||
Make sure that you use the same IDs used during calibration later for the other scripts. That's how LeRobot finds the calibration files.
|
||||
|
||||
### Teleoperation
|
||||
|
||||
Teleoperating with two cameras and displaying the data with Rerun.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{ top: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
### Recording a dataset
|
||||
|
||||
The dataset is automatically uploaded to the server and saved under repo_id, make sure you are logged in to your HF account with CLI:
|
||||
`hf auth login`
|
||||
|
||||
You can get the token from: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{ top: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--dataset.repo_id=${HF_USER}/so101_dataset_test \
|
||||
--dataset.num_episodes=30 \
|
||||
--dataset.single_task="put the red brick in a bowl" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
While collecting the dataset you can control the process with your keyboard:
|
||||
Control the data recording flow using keyboard shortcuts:
|
||||
|
||||
- Press **Right Arrow (`→`)**: Save episode and move to the next.
|
||||
- Press **Left Arrow (`←`)**: Delete current episode and retry.
|
||||
- Press **Escape (`ESC`)**: Stop, encode videos, and upload.
|
||||
|
||||
### Training
|
||||
|
||||
Depending on your hardware training the policy might take a few hours. That's how you train simple `ACT` policy:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/so101_dataset_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_so101_test \
|
||||
--job_name=act_so101_test \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=${HF_USER}/policy_test \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
- Policy Types: `act`, `diffusion`, `smolvla`, `pi05`
|
||||
- Devices: `cuda` (NVIDIA), `mps` (Apple Silicon), `cpu`
|
||||
|
||||
If you want to fine-tune a specific model you can provide the path to the model. In this case path is enough and type can be skipped.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/so101_dataset_test \
|
||||
--policy.path=username/the_policy_to_finetune \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=${HF_USER}/policy_test \
|
||||
--output_dir=outputs/train/act_so101_test \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
Inference means running the trained policy/model on a robot. For that we use `lerobot-rollout`. You will need to provide a path to your policy. It can be a local path or a path to Hugging Face for example "lerobot/folding_latest". Your cameras configuration needs to match what was used when collecting the dataset. Duration is in seconds if unspecified, it will run forever.
|
||||
|
||||
> [!TIP]
|
||||
> If you are using the previous release V0.5.1 instead of `lerobot-rollout` you need to use `lerobot-record`. More information [here](https://huggingface.co/docs/lerobot/v0.5.1/en/il_robots#run-inference-and-evaluate-your-policy).
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video1, width: 640, height: 480, fps: 30}, side: {type: opencv, index_or_path: /dev/video5, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Put lego brick into the transparent box" \
|
||||
--duration=60
|
||||
```
|
||||
@@ -1,277 +0,0 @@
|
||||
# Using Subtasks in LeRobot Datasets
|
||||
|
||||
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
|
||||
|
||||
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
|
||||
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
|
||||
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
|
||||
|
||||
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
|
||||
|
||||
## What are Subtasks?
|
||||
|
||||
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
|
||||
|
||||
1. "Approach the apple"
|
||||
2. "Grasp the apple"
|
||||
3. "Lift the apple"
|
||||
4. "Move to basket"
|
||||
5. "Release the apple"
|
||||
|
||||
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
|
||||
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
<p>
|
||||
<em>Figure: Overview of subtask annotation.</em>
|
||||
</p>
|
||||
|
||||
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
Subtask information is stored in the dataset metadata:
|
||||
|
||||
```
|
||||
my-dataset/
|
||||
├── data/
|
||||
│ └── ...
|
||||
├── meta/
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ ├── tasks.parquet
|
||||
│ ├── subtasks.parquet # Subtask index → subtask string mapping
|
||||
│ └── episodes/
|
||||
│ └── ...
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Subtasks Parquet File
|
||||
|
||||
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
|
||||
|
||||
| subtask_index | subtask (index column) |
|
||||
| ------------- | ---------------------- |
|
||||
| 0 | "Approach the apple" |
|
||||
| 1 | "Grasp the apple" |
|
||||
| 2 | "Lift the apple" |
|
||||
| ... | ... |
|
||||
|
||||
### Frame-Level Annotations
|
||||
|
||||
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
|
||||
|
||||
```python
|
||||
# Example frame data in the parquet file
|
||||
{
|
||||
"index": 42,
|
||||
"timestamp": 1.4,
|
||||
"episode_index": 0,
|
||||
"task_index": 0,
|
||||
"subtask_index": 2, # References "Lift the apple"
|
||||
"observation.state": [...],
|
||||
"action": [...],
|
||||
}
|
||||
```
|
||||
|
||||
## Annotating Datasets with Subtasks
|
||||
|
||||
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
|
||||
|
||||
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
|
||||
|
||||
After completing your annotation:
|
||||
|
||||
1. Click "Push to Hub" to upload your annotated dataset
|
||||
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
|
||||
|
||||
## Loading Datasets with Subtasks
|
||||
|
||||
When you load a dataset with subtask annotations, the subtask information is automatically available:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Load a dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Access a sample
|
||||
sample = dataset[100]
|
||||
|
||||
# The sample includes both task and subtask information
|
||||
print(sample["task"]) # "Collect the fruit"
|
||||
print(sample["subtask"]) # "Grasp the apple"
|
||||
print(sample["task_index"]) # tensor(0)
|
||||
print(sample["subtask_index"]) # tensor(2)
|
||||
```
|
||||
|
||||
### Checking for Subtask Support
|
||||
|
||||
You can check if a dataset has subtask annotations:
|
||||
|
||||
```python
|
||||
# Check if subtasks are available
|
||||
has_subtasks = (
|
||||
"subtask_index" in dataset.features
|
||||
and dataset.meta.subtasks is not None
|
||||
)
|
||||
|
||||
if has_subtasks:
|
||||
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
|
||||
print("Subtasks:", list(dataset.meta.subtasks.index))
|
||||
```
|
||||
|
||||
## Using Subtasks for Training
|
||||
|
||||
### With the Tokenizer Processor
|
||||
|
||||
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
|
||||
|
||||
```python
|
||||
from lerobot.processor import TokenizerProcessorStep
|
||||
|
||||
# Create a tokenizer processor step
|
||||
tokenizer_processor = TokenizerProcessorStep(
|
||||
tokenizer_name_or_path="google/paligemma-3b-pt-224",
|
||||
padding="max_length",
|
||||
max_length=64,
|
||||
)
|
||||
|
||||
# The processor will automatically tokenize subtasks if present in the batch
|
||||
# and add them to the observation under:
|
||||
# - "observation.subtask.tokens"
|
||||
# - "observation.subtask.attention_mask"
|
||||
```
|
||||
|
||||
When subtasks are available in the batch, the tokenizer processor adds:
|
||||
|
||||
- `observation.subtask.tokens`: Tokenized subtask text
|
||||
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
|
||||
|
||||
### DataLoader with Subtasks
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
# Access subtask information in the batch
|
||||
subtasks = batch["subtask"] # List of subtask strings
|
||||
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
|
||||
|
||||
# Use for training hierarchical policies or reward models
|
||||
print(f"Batch subtasks: {set(subtasks)}")
|
||||
```
|
||||
|
||||
## Example Datasets with Subtask Annotations
|
||||
|
||||
Try loading a dataset with subtask annotations:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Example dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Explore the subtasks
|
||||
print("Available subtasks:")
|
||||
for subtask_name in dataset.meta.subtasks.index:
|
||||
print(f" - {subtask_name}")
|
||||
|
||||
# Get subtask distribution
|
||||
subtask_counts = {}
|
||||
for i in range(len(dataset)):
|
||||
sample = dataset[i]
|
||||
subtask = sample["subtask"]
|
||||
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
|
||||
|
||||
print("\nSubtask distribution:")
|
||||
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {subtask}: {count} frames")
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Hierarchical Policy Training
|
||||
|
||||
Train policies that predict both actions and current subtask:
|
||||
|
||||
```python
|
||||
class HierarchicalPolicy(nn.Module):
|
||||
def __init__(self, num_subtasks):
|
||||
super().__init__()
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
|
||||
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
actions = self.action_head(features)
|
||||
subtask_logits = self.subtask_head(features)
|
||||
return actions, subtask_logits
|
||||
```
|
||||
|
||||
### 2. Stage-Aware Reward Modeling (SARM)
|
||||
|
||||
Build reward models that understand task progression:
|
||||
|
||||
```python
|
||||
# SARM predicts:
|
||||
# - Stage: Which subtask is being executed (discrete)
|
||||
# - Progress: How far along the subtask (continuous 0-1)
|
||||
|
||||
class SARMRewardModel(nn.Module):
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
stage_logits = self.stage_classifier(features)
|
||||
progress = self.progress_regressor(features)
|
||||
return stage_logits, progress
|
||||
```
|
||||
|
||||
### 3. Progress Visualization
|
||||
|
||||
Monitor robot execution by tracking subtask progression:
|
||||
|
||||
```python
|
||||
def visualize_execution(model, observations):
|
||||
for t, obs in enumerate(observations):
|
||||
action, subtask_logits = model(obs)
|
||||
predicted_subtask = subtask_names[subtask_logits.argmax()]
|
||||
print(f"t={t}: Executing '{predicted_subtask}'")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### LeRobotDataset Properties
|
||||
|
||||
| Property | Type | Description |
|
||||
| --------------------------- | ---------------------- | ------------------------------------------ |
|
||||
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
|
||||
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
|
||||
|
||||
### Sample Keys
|
||||
|
||||
When subtasks are available, each sample includes:
|
||||
|
||||
| Key | Type | Description |
|
||||
| --------------- | -------------- | ------------------------------------ |
|
||||
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
|
||||
| `subtask` | `str` | Natural language subtask description |
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
|
||||
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
|
||||
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
|
||||
@@ -194,7 +194,7 @@ lerobot-record \
|
||||
--dataset.single_task="Navigate around obstacles" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
@@ -105,10 +105,12 @@ These results demonstrate GR00T's strong generalization capabilities across dive
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout\
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--robot.type=bi_so_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
@@ -119,14 +121,12 @@ lerobot-record \
|
||||
}' \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10
|
||||
--duration=600
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
@@ -232,7 +232,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -278,6 +278,6 @@ lerobot-record \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760431541",
|
||||
id="my_red_robot_arm",
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
)
|
||||
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=my_awesome_leader_arm \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem5AB90687491 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem5AB90689011 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
</hfoption>
|
||||
@@ -122,34 +122,48 @@ lerobot-teleoperate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
import time
|
||||
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
|
||||
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
|
||||
|
||||
camera_config = {
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
|
||||
}
|
||||
|
||||
robot_config = KochFollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
id="my_red_robot_arm",
|
||||
cameras=camera_config
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
teleop_config = KochLeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = KochFollower(robot_config)
|
||||
teleop_device = KochLeader(teleop_config)
|
||||
init_rerun(session_name="teleoperation")
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop_device = SO101Leader(teleop_config)
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
TARGET_HZ = 30
|
||||
TIME_PER_FRAME = 1.0 / TARGET_HZ
|
||||
|
||||
while True:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
observation = robot.get_observation()
|
||||
action = teleop_device.get_action()
|
||||
robot.send_action(action)
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
|
||||
elapsed_time = time.perf_counter() - start_time
|
||||
sleep_time = TIME_PER_FRAME - elapsed_time
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -193,7 +207,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
</hfoption>
|
||||
@@ -202,10 +216,11 @@ lerobot-record \
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
|
||||
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
@@ -218,71 +233,56 @@ EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
|
||||
# Create robot configuration
|
||||
robot_config = SO100FollowerConfig(
|
||||
id="my_awesome_follower_arm",
|
||||
cameras={
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
||||
},
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
)
|
||||
|
||||
teleop_config = SO100LeaderConfig(
|
||||
id="my_awesome_leader_arm",
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop = SO100Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
def main():
|
||||
# Create robot configuration
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop = SO101Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
@@ -291,26 +291,50 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# finalize dataset
|
||||
log_say("Finalizing dataset...")
|
||||
dataset.finalize()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -348,7 +372,7 @@ The `record` function provides a suite of tools for capturing and managing data
|
||||
##### 2. Checkpointing and Resuming
|
||||
|
||||
- Checkpoints are automatically created during recording.
|
||||
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
|
||||
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
|
||||
- To start recording from scratch, **manually delete** the dataset directory.
|
||||
|
||||
##### 3. Recording Parameters
|
||||
@@ -422,7 +446,7 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
robot.connect()
|
||||
@@ -490,6 +514,83 @@ Additionally you can provide extra `tags` or specify a `license` for your model
|
||||
|
||||
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
#### Train using Hugging Face Jobs
|
||||
|
||||
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
|
||||
|
||||
To run the training use this command:
|
||||
|
||||
<hfoptions id="train_with_hf_jobs">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
hf jobs run \
|
||||
--flavor a10g-small \
|
||||
--timeout 4h \
|
||||
--secrets HF_TOKEN \
|
||||
huggingface/lerobot-gpu:latest \
|
||||
-- \
|
||||
python -m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=username/dataset \
|
||||
--policy.type=act \
|
||||
--steps=5000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=username/your_policy \
|
||||
--log_freq=100
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from huggingface_hub import run_job, get_token
|
||||
|
||||
run_name = "act_so101_hf_jobs"
|
||||
dataset_id = "username/dataset"
|
||||
user_hub_id = "username"
|
||||
|
||||
command_args = [
|
||||
"python", "-m", "lerobot.scripts.lerobot_train",
|
||||
"--dataset.repo_id", dataset_id,
|
||||
"--policy.type", "act",
|
||||
"--steps", "5000",
|
||||
"--batch_size", "16",
|
||||
"--num_workers", "4",
|
||||
"--policy.device", "cuda",
|
||||
"--log_freq", "100",
|
||||
"--save_freq", "1000",
|
||||
"--save_checkpoint", "true",
|
||||
"--wandb.enable", "false",
|
||||
"--policy.repo_id", f"{user_hub_id}/{run_name}"
|
||||
]
|
||||
|
||||
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
|
||||
|
||||
job_info = run_job(
|
||||
image="huggingface/lerobot-gpu:latest",
|
||||
command=command_args,
|
||||
flavor="a10g-small",
|
||||
timeout="4h",
|
||||
secrets={"HF_TOKEN": get_token()}
|
||||
)
|
||||
|
||||
print("\n🚀 Job successfully launched!")
|
||||
print(f"🔹 Job ID: {job_info.id}")
|
||||
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
|
||||
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
|
||||
For longer training sessions increase the timeout.
|
||||
|
||||
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
|
||||
|
||||
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
147
docs/source/language_and_recipes.mdx
Normal file
147
docs/source/language_and_recipes.mdx
Normal file
@@ -0,0 +1,147 @@
|
||||
# Language columns and recipes
|
||||
|
||||
Most LeRobot datasets ship with a single `task` string per episode — fine for
|
||||
short, single-instruction skills, but not enough for the longer-horizon,
|
||||
multi-modal robot policies the field is moving toward (high-level planning,
|
||||
memory, interjections, VQA, tool use). To support those policies without
|
||||
forking the dataset format, LeRobot extends `LeRobotDataset` with two optional
|
||||
language columns and a small recipe layer that turns those rows into
|
||||
chat-style training samples on the fly.
|
||||
|
||||
The design splits cleanly into three layers:
|
||||
|
||||
1. **Data in the dataset** — language annotations stored next to frames in
|
||||
`data/chunk-*/file-*.parquet` as two optional columns (`language_persistent`
|
||||
and `language_events`). Datasets without these columns keep their existing
|
||||
behavior.
|
||||
2. **Recipe** — a YAML file that declares which annotation rows to bind and
|
||||
how to lay them out as chat turns (`role`, `content`, optional images,
|
||||
optional tool calls). Recipes are pure config; no Python required to add a
|
||||
new one.
|
||||
3. **Training format** — at sample time, `RenderMessagesStep` resolves the
|
||||
recipe against the per-frame annotations and emits HF-style `messages` plus
|
||||
LeRobot-specific sidecars (`message_streams`, `target_message_indices`)
|
||||
that policy processors consume.
|
||||
|
||||
This page describes each layer in turn.
|
||||
|
||||
## Layer 1 — language columns in the dataset
|
||||
|
||||
The two optional columns live next to frame data in
|
||||
`data/chunk-*/file-*.parquet`:
|
||||
|
||||
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
|
||||
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
|
||||
|
||||
Both columns share the same row shape (event rows omit `timestamp` because the
|
||||
frame the row sits on already provides it):
|
||||
|
||||
```text
|
||||
role: string
|
||||
content: string | null
|
||||
style: string | null
|
||||
timestamp: float32 # persistent rows only
|
||||
camera: string | null # observation.images.* feature key, view-dependent rows only
|
||||
tool_calls: list[Json] | null
|
||||
```
|
||||
|
||||
The `camera` field tags rows whose `content` is grounded in a specific camera
|
||||
view. Rows of view-dependent styles (`vqa` and `trace`) MUST set `camera` to
|
||||
the matching `observation.images.*` feature key. Rows of every other style —
|
||||
including `motion`, which describes robot-frame primitives in joint / Cartesian
|
||||
terms — MUST leave `camera` as `null`. Pipeline writers and the validator
|
||||
enforce this via `validate_camera_field(style, camera)`.
|
||||
|
||||
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
|
||||
|
||||
### Architecture
|
||||
|
||||
The language stack itself has three internal modules backing layer 1:
|
||||
|
||||
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
|
||||
2. `lerobot.datasets.language_render` resolves rows and renders messages.
|
||||
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
|
||||
|
||||
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
|
||||
|
||||
## Layer 2 — recipe anatomy
|
||||
|
||||
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
|
||||
declare which annotation rows to pull (via `bindings`) and how to compose them
|
||||
into chat turns (`messages`).
|
||||
|
||||
```yaml
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
|
||||
```
|
||||
|
||||
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
|
||||
time, exactly one branch is selected deterministically from the sample index,
|
||||
so different frames train different objectives (e.g. memory updates vs.
|
||||
low-level execution vs. VQA) without any Python wiring.
|
||||
|
||||
### Temporal semantics
|
||||
|
||||
Persistent styles are active after emission until replaced:
|
||||
|
||||
- `active_at(t, style=subtask)`
|
||||
- `nth_prev(style=memory, offset=1)`
|
||||
- `nth_next(style=subtask, offset=1)`
|
||||
|
||||
Event styles only exist on their exact timestamp:
|
||||
|
||||
- `emitted_at(t, style=interjection)`
|
||||
- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
|
||||
- `emitted_at(t, role=assistant, tool_name=say)`
|
||||
|
||||
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
|
||||
|
||||
### View-dependent resolution
|
||||
|
||||
For view-dependent styles (`vqa` and `trace`), the resolver gains a
|
||||
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
|
||||
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
|
||||
camera at the same timestamp; without `camera=`, those resolvers see two
|
||||
matches and raise an ambiguity error. Recipes consume each camera through its
|
||||
own binding plus a matching image block, e.g.
|
||||
|
||||
```yaml
|
||||
ask_vqa_top:
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- { type: image, feature: observation.images.top }
|
||||
- { type: text, text: "${vqa_query}" }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${vqa}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
if_present: vqa,
|
||||
}
|
||||
```
|
||||
|
||||
Add one such sub-recipe per camera the dataset records.
|
||||
|
||||
## Layer 3 — training format
|
||||
|
||||
Rendered samples use HF-style chat messages plus LeRobot sidecars:
|
||||
|
||||
```python
|
||||
sample["messages"]
|
||||
sample["message_streams"]
|
||||
sample["target_message_indices"]
|
||||
```
|
||||
|
||||
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
|
||||
|
||||
## Graceful absence
|
||||
|
||||
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
|
||||
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.
|
||||
@@ -10,6 +10,7 @@ This docs will guide you to:
|
||||
- Stream datasets without downloading using `StreamingLeRobotDataset`
|
||||
- Apply image transforms for data augmentation during training
|
||||
- Migrate existing `v2.1` datasets to `v3.0`
|
||||
- Experiment with other `LeRobotDataset` formats and implementations like Lance
|
||||
|
||||
## What’s new in `v3`
|
||||
|
||||
@@ -43,7 +44,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
@@ -315,3 +316,39 @@ Dataset v3.0 uses incremental parquet writing with buffered metadata for efficie
|
||||
- Ensures the dataset is valid for loading
|
||||
|
||||
Without calling `finalize()`, your parquet files will be incomplete and the dataset won't load properly.
|
||||
|
||||
## Other formats and implementations
|
||||
|
||||
### Lance
|
||||
|
||||
Lance is a useful format for multimodal AI datasets, especially for large-scale training requiring high performance IO and random access.
|
||||
|
||||
The `lerobot-lancedb` package implements `LeRobotLanceDataset` (for JPEG images) and `LeRobotLanceVideoDataset` (for mp4 videos).
|
||||
Those two storage layouts both subclass LeRobotDataset and can provide data loading speed ups.
|
||||
|
||||
`LeRobotLanceDataset` is a drop-in replacement for `LeRobotDataset`:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDatasetMetadata
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot_lancedb import LeRobotLanceDataset, LeRobotLanceVideoDataset
|
||||
|
||||
cfg = DiffusionConfig(...)
|
||||
meta = LeRobotDatasetMetadata(root=local_dataset_path) # or use repo_id=... to load metadata from the Hub
|
||||
delta_timestamps = {...}
|
||||
|
||||
# Use LeRobotLanceDataset for image datasets
|
||||
dataset = LeRobotLanceDataset(
|
||||
root=local_dataset_path, # or use repo_id=... to stream from the Hub
|
||||
delta_timestamps=delta_timestamps,
|
||||
return_uint8=True,
|
||||
)
|
||||
# Or use LeRobotLanceVideoDataset for video datasets:
|
||||
dataset = LeRobotLanceVideoDataset(
|
||||
root=local_dataset_path, # or use repo_id=... to stream from the Hub
|
||||
delta_timestamps=delta_timestamps,
|
||||
return_uint8=True,
|
||||
)
|
||||
```
|
||||
|
||||
Join the discussion on [Github](https://github.com/huggingface/lerobot/issues/3608) and explore the `lerobot-lancedb` documentation [here](https://lancedb.github.io/lerobot-lancedb/).
|
||||
|
||||
433
docs/source/molmoact2.mdx
Normal file
433
docs/source/molmoact2.mdx
Normal file
@@ -0,0 +1,433 @@
|
||||
# MolmoAct2 Policy
|
||||
|
||||
MolmoAct2 is the LeRobot policy implementation of
|
||||
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into the LeRobot
|
||||
training, evaluation, checkpointing, and dataset interfaces for easier use with
|
||||
LeRobot datasets.
|
||||
|
||||
This implementation currently supports training and evaluation for the regular
|
||||
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
|
||||
not included in this LeRobot policy yet and is coming soon.
|
||||
|
||||
For the original MolmoAct2 training code used for the experiments reported in
|
||||
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
Install LeRobot with the MolmoAct2 optional dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[molmoact2]"
|
||||
```
|
||||
|
||||
To run the models in this repository, you need an NVIDIA GPU. The measurements
|
||||
below were taken on a single NVIDIA H100 80GB with bf16 model loading, LIBERO with two RGB cameras. MolmoAct2 rows use `chunk_size=10`, action dim 7
|
||||
padded to `expected_max_action_dim=32`, and `num_flow_timesteps=8`. Training measurements use
|
||||
`gradient_checkpointing=true` and include the forward pass, backward pass,
|
||||
gradient clipping, optimizer step, and optimizer state allocation. Values are
|
||||
peak GPU memory sampled with `nvidia-smi`. Leave a few GiB of headroom for
|
||||
dataloader workers, CUDA context, and fragmentation.
|
||||
|
||||
Multi-GPU training through `accelerate` increases throughput and global batch
|
||||
size, but this LeRobot port does not currently expose the original MolmoAct2
|
||||
`fsdp_devices` model-parallel training path. The current training script has
|
||||
not been tested for multi-node training.
|
||||
|
||||
| Mode | Peak Memory, bs=8 | Peak Memory, bs=16 | Peak Memory, bs=32 |
|
||||
| ------------------------------------------------ | ----------------: | -----------------: | -----------------: |
|
||||
| Inference, continuous, CUDA graph enabled (bs=1) | 12.1 GiB | - | - |
|
||||
| Fine-tuning, action expert only, continuous | 16.5 GiB | 18.3 GiB | 21.4 GiB |
|
||||
| Fine-tuning, LoRA VLM, both action modes | 20.2 GiB | 26.8 GiB | 41.3 GiB |
|
||||
| Fine-tuning, full model, both action modes | 48.3 GiB | 49.8 GiB | 60.1 GiB |
|
||||
|
||||
The repo has been tested with Ubuntu 22.04.
|
||||
|
||||
## Usage
|
||||
|
||||
To use MolmoAct2 in a LeRobot training config, set:
|
||||
|
||||
```python
|
||||
policy.type=molmoact2
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
MolmoAct2 can be fine-tuned from either the released MolmoAct2 Hugging Face
|
||||
checkpoint format or from a checkpoint already saved by LeRobot. Both routes use
|
||||
the same LeRobot training loop, dataset transforms, checkpoint saving, and
|
||||
logging. The difference is only how the initial policy weights and processor
|
||||
state are loaded.
|
||||
|
||||
### Training With Original MolmoAct2 Weight
|
||||
|
||||
Use `policy.checkpoint_path` when starting from a released MolmoAct2 checkpoint,
|
||||
for example `allenai/MolmoAct2` or `allenai/MolmoAct2-LIBERO`. LeRobot will load
|
||||
the original HF model files, then build its own policy processor from the
|
||||
dataset metadata and the policy options below.
|
||||
|
||||
The command below shows full fine-tuning on the merged LIBERO dataset. It uses
|
||||
bf16 model loading, 8 flow timesteps, LeRobot dataset statistics, image
|
||||
augmentation, and LeRobot's checkpointing/logging path.
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--num_processes=8 \
|
||||
--mixed_precision=bf16 \
|
||||
-m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.video_backend=pyav \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--policy.type=molmoact2 \
|
||||
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
|
||||
--policy.device=cuda \
|
||||
--policy.action_mode=both \
|
||||
--policy.chunk_size=10 \
|
||||
--policy.n_action_steps=10 \
|
||||
--policy.setup_type="single franka robotic arm in libero" \
|
||||
--policy.control_mode="delta end-effector pose" \
|
||||
--policy.image_keys='["observation.images.image","observation.images.wrist_image"]' \
|
||||
--policy.model_dtype=bfloat16 \
|
||||
--policy.num_flow_timesteps=8 \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--policy.freeze_embedding=true \
|
||||
--policy.normalize_gripper=false \
|
||||
--policy.enable_knowledge_insulation=false \
|
||||
--policy.push_to_hub=false \
|
||||
--wandb.enable=true \
|
||||
--wandb.entity=<wandb_entity> \
|
||||
--wandb.project=<wandb_project> \
|
||||
--job_name=<job_name> \
|
||||
--output_dir=outputs/<job_name> \
|
||||
--steps=10000 \
|
||||
--batch_size=32 \
|
||||
--num_workers=4 \
|
||||
--log_freq=20 \
|
||||
--eval_freq=-1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2000
|
||||
```
|
||||
|
||||
### Training With LeRobot MolmoAct2 Weight
|
||||
|
||||
Use `policy.path` when starting from a MolmoAct2 checkpoint that was saved by
|
||||
LeRobot, either from a local `pretrained_model` directory or from the Hub. This
|
||||
restores the saved LeRobot policy config, model weights, processor, and
|
||||
normalization statistics. You can still override training-time options such as
|
||||
`batch_size`, `steps`, LoRA flags, or `policy.action_mode`.
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--num_processes=8 \
|
||||
--mixed_precision=bf16 \
|
||||
-m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.video_backend=pyav \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--policy.path=/path/to/pretrained_model \
|
||||
--policy.device=cuda \
|
||||
--policy.action_mode=both \
|
||||
--policy.chunk_size=10 \
|
||||
--policy.n_action_steps=10 \
|
||||
--policy.model_dtype=bfloat16 \
|
||||
--policy.num_flow_timesteps=8 \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--wandb.enable=true \
|
||||
--wandb.entity=<wandb_entity> \
|
||||
--wandb.project=<wandb_project> \
|
||||
--job_name=<job_name> \
|
||||
--output_dir=outputs/<job_name> \
|
||||
--steps=10000 \
|
||||
--batch_size=32 \
|
||||
--num_workers=4 \
|
||||
--log_freq=20 \
|
||||
--eval_freq=-1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2000
|
||||
```
|
||||
|
||||
### Common Practices
|
||||
|
||||
For fine-tuning on a comparatively small dataset, such as a single LIBERO suite
|
||||
or a real-world dataset with less than 200 demonstrations, a global batch size of
|
||||
16 to 32 is a good starting point. In these settings, `policy.enable_lora_vlm=true` or `policy.train_action_expert_only=true` is also a practical choice. In both
|
||||
cases, we intentionally keep the action expert fully trainable, which we found
|
||||
to be crucial for model performance. For larger fine-tuning datasets, larger
|
||||
global batch sizes and full fine-tuning are usually preferred.
|
||||
|
||||
### Common Policy Options
|
||||
|
||||
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint to initialize from.
|
||||
Use this for released MolmoAct2 weights.
|
||||
- `policy.path`: LeRobot checkpoint to initialize from. Use this for checkpoints
|
||||
created by LeRobot training.
|
||||
- `policy.action_mode`: training target, one of `continuous`, `discrete`, or
|
||||
`both`. `both` trains the flow-matching action expert and the discrete
|
||||
action-token loss.
|
||||
- `policy.train_action_expert_only`: trains only parameters whose names contain
|
||||
`action_expert`. It requires `policy.action_mode=continuous`.
|
||||
- `policy.enable_lora_vlm`: enables LoRA on VLM linear layers. Use
|
||||
`policy.enable_lora_action_expert=true` only if LoRA should also cover action
|
||||
expert linear layers. When `policy.enable_lora_action_expert=false`, the
|
||||
action expert base weights remain fully trainable while the VLM is trained
|
||||
through LoRA adapters. When `policy.enable_lora_action_expert=true`, the
|
||||
action expert is also adapter-tuned instead of fully fine-tuned.
|
||||
- `policy.enable_knowledge_insulation`: when `true`, detaches action-expert
|
||||
context K/V states before the action loss. The default is `false`.
|
||||
- `policy.chunk_size`: action horizon used by the policy. For LIBERO we use
|
||||
`10`. This LeRobot port overrides the loaded checkpoint's
|
||||
`max_action_horizon` with this value.
|
||||
- `policy.n_action_steps`: number of actions consumed from each predicted
|
||||
chunk before querying the policy again. For LIBERO, set it to `chunk_size`.
|
||||
- `policy.setup_type`: text inserted into the prompt to describe the robot and
|
||||
scene, e.g. `single franka robotic arm in libero`. More examples are listed
|
||||
in the `metadata_by_tag` entries of
|
||||
[`norm_stats.json`](https://huggingface.co/allenai/MolmoAct2/blob/main/norm_stats.json).
|
||||
- `policy.control_mode`: text inserted into the prompt to describe the action
|
||||
space, e.g. `delta end-effector pose` or `absolute joint pose`.
|
||||
- `policy.image_keys`: ordered LeRobot image observation keys passed to the
|
||||
processor.
|
||||
- `policy.model_dtype`: checkpoint/forward dtype, one of `float32`,
|
||||
`bfloat16`, or `float16`. Use `bfloat16` for normal training.
|
||||
- `policy.num_flow_timesteps`: number of flow-matching timesteps sampled per
|
||||
example during training. We use `8` for fine-tuning.
|
||||
- `policy.num_inference_steps`: optional override for continuous action
|
||||
generation steps at inference time.
|
||||
- `policy.gradient_checkpointing`: enables checkpointing in the VLM/action path
|
||||
to reduce activation memory.
|
||||
- `policy.freeze_embedding`: freezes input embeddings. The default is `true`.
|
||||
- `policy.normalize_gripper`: controls whether gripper dimensions are included
|
||||
in state/action quantile normalization. The default is `false`.
|
||||
- `policy.normalize_language`: normalizes task strings before prompt
|
||||
construction. The default is `true`.
|
||||
- `policy.mask_action_dim_padding`: masks padded dimensions in the flow loss.
|
||||
Released checkpoints use `policy.expected_max_action_dim=32`.
|
||||
- `policy.max_sequence_length`: optional manual sequence cap. Leave unset to
|
||||
infer it from images, state dimension, action dimension, action horizon, and
|
||||
discrete-action mode.
|
||||
|
||||
### Learning Rates
|
||||
|
||||
MolmoAct2 uses parameter-group learning rates to match the original MolmoAct2
|
||||
fine-tuning experiments.
|
||||
|
||||
- Full fine-tuning uses `policy.optimizer_lr=1e-5` for the VLM,
|
||||
`policy.optimizer_vit_lr=5e-6` for the vision tower,
|
||||
`policy.optimizer_connector_lr=5e-6` for image connector layers, and
|
||||
`policy.optimizer_action_expert_lr=5e-5` for the action expert.
|
||||
- LoRA VLM fine-tuning sets the VLM, vision, and connector LoRA parameter
|
||||
groups to `5e-5` when `policy.enable_lora_vlm=true`. By default,
|
||||
`policy.enable_lora_action_expert=false`, so the action expert is still fully
|
||||
fine-tuned with `policy.optimizer_action_expert_lr`. If
|
||||
`policy.enable_lora_action_expert=true`, the action expert is trained through
|
||||
LoRA adapters instead.
|
||||
- Action-expert-only fine-tuning trains only the action expert and uses
|
||||
`policy.optimizer_action_expert_lr=5e-5`.
|
||||
|
||||
You can override the full fine-tuning and action-expert learning rates with
|
||||
`policy.optimizer_lr`, `policy.optimizer_vit_lr`,
|
||||
`policy.optimizer_connector_lr`, and `policy.optimizer_action_expert_lr`.
|
||||
Scheduler settings can be changed with `policy.scheduler_warmup_steps`,
|
||||
`policy.scheduler_decay_steps`, and `policy.scheduler_decay_lr`.
|
||||
|
||||
### Dataset Quantile Statistics
|
||||
|
||||
MolmoAct2 defaults to quantile normalization for state and action features. If
|
||||
your dataset has not been converted with quantile statistics, you can add them
|
||||
with:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset
|
||||
```
|
||||
|
||||
Alternatively, train MolmoAct2 with mean/std normalization:
|
||||
|
||||
```bash
|
||||
--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
Evaluation also supports both LeRobot-saved checkpoints and original MolmoAct2
|
||||
HF checkpoints. For LIBERO replication, keep the EGL rendering environment
|
||||
fixed and use `policy.per_episode_seed=true`.
|
||||
|
||||
**Important:** We found that `num_steps_wait=10` does not reliably let the
|
||||
LIBERO scene stabilize and can degrade measured success. All LIBERO evaluation
|
||||
results reported here use `num_steps_wait=50`.
|
||||
|
||||
### Evaluation With LeRobot MolmoAct2 Weight
|
||||
|
||||
Use `policy.path` for a checkpoint saved by LeRobot. The saved processor and
|
||||
normalization statistics are restored together with the model.
|
||||
|
||||
```bash
|
||||
export MUJOCO_GL=egl
|
||||
export PYOPENGL_PLATFORM=egl
|
||||
export OMP_NUM_THREADS=1
|
||||
export MKL_NUM_THREADS=1
|
||||
|
||||
lerobot-eval \
|
||||
--policy.path=allenai/MolmoAct2-LIBERO-LeRobot \
|
||||
--policy.inference_action_mode=continuous \
|
||||
--policy.model_dtype=bfloat16 \
|
||||
--policy.use_amp=true \
|
||||
--policy.enable_inference_cuda_graph=true \
|
||||
--policy.device=cuda \
|
||||
--policy.per_episode_seed=true \
|
||||
--policy.eval_seed=1000 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10,libero_goal,libero_object,libero_spatial \
|
||||
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=50 \
|
||||
--seed=1000
|
||||
```
|
||||
|
||||
### Evaluation With Original MolmoAct2 Weight
|
||||
|
||||
You can evaluate a released Hugging Face checkpoint directly without first
|
||||
converting it to a LeRobot checkpoint. In this case, set
|
||||
`policy.checkpoint_path` to the HF model repo and provide `policy.norm_tag`.
|
||||
For LIBERO, `policy.norm_tag=libero` loads the LIBERO action/state
|
||||
normalization statistics, action horizon, prompt metadata, and image-key order
|
||||
from the checkpoint's `norm_stats.json`.
|
||||
|
||||
To fully replicate the MolmoAct2 paper results with released Hugging Face
|
||||
checkpoints, we recommend using the v0.5.1-pinned
|
||||
[`allenai/lerobot` `molmoact2-hf-inference`](https://github.com/allenai/lerobot/tree/molmoact2-hf-inference)
|
||||
branch. That branch matches the original evaluation settings used for the
|
||||
reported numbers.
|
||||
|
||||
```bash
|
||||
export MUJOCO_GL=egl
|
||||
export PYOPENGL_PLATFORM=egl
|
||||
export OMP_NUM_THREADS=1
|
||||
export MKL_NUM_THREADS=1
|
||||
|
||||
lerobot-eval \
|
||||
--policy.type=molmoact2 \
|
||||
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
|
||||
--policy.norm_tag=libero \
|
||||
--policy.inference_action_mode=continuous \
|
||||
--policy.model_dtype=float32 \
|
||||
--policy.use_amp=false \
|
||||
--policy.enable_inference_cuda_graph=true \
|
||||
--policy.device=cuda \
|
||||
--policy.per_episode_seed=true \
|
||||
--policy.eval_seed=1000 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_goal \
|
||||
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=50 \
|
||||
--seed=1000
|
||||
```
|
||||
|
||||
Use `--env.task=libero_10,libero_goal,libero_object,libero_spatial` to run the
|
||||
full LIBERO suite. The same command works for other released MolmoAct2
|
||||
checkpoints as long as the requested `policy.norm_tag` exists in that
|
||||
checkpoint's `norm_stats.json`.
|
||||
|
||||
### Common Evaluation Options
|
||||
|
||||
- `policy.inference_action_mode`: required for rollout. Use `continuous` for
|
||||
flow-matching inference or `discrete` for action-token inference. It must be
|
||||
compatible with the training-time `policy.action_mode` saved in the
|
||||
checkpoint.
|
||||
- `policy.path`: LeRobot checkpoint path or Hub repo. Use this for checkpoints
|
||||
saved by LeRobot.
|
||||
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint path or Hub repo.
|
||||
Use this with `policy.type=molmoact2` and `policy.norm_tag`.
|
||||
- `policy.norm_tag`: selects normalization statistics, prompt metadata,
|
||||
image-key order, and action horizon from the original checkpoint's
|
||||
`norm_stats.json`. It is required for direct original-HF checkpoint
|
||||
evaluation.
|
||||
- `policy.model_dtype`: model load/forward dtype. Use `bfloat16` for normal
|
||||
GPU evaluation. Use `float32` only when you explicitly want fp32 inference.
|
||||
- `policy.use_amp`: runs the policy forward under autocast during eval. For
|
||||
`model_dtype=bfloat16`, keep this enabled.
|
||||
- `policy.enable_inference_cuda_graph`: enables the MolmoAct2 inference CUDA
|
||||
graph path for faster repeated continuous-action rollout.
|
||||
- `policy.per_episode_seed` and `policy.eval_seed`: make stochastic continuous
|
||||
action generation deterministic per episode for replication.
|
||||
- `env.task`: comma-separated LIBERO suites or a single suite. Use
|
||||
`libero_10,libero_goal,libero_object,libero_spatial` for the full benchmark.
|
||||
- `env.camera_name_mapping`: maps LIBERO camera names to the image keys expected
|
||||
by the policy processor.
|
||||
|
||||
## Performance Results
|
||||
|
||||
### LIBERO Benchmark Results
|
||||
|
||||
MolmoAct2 has demonstrated strong performance on the LIBERO benchmark suite. To
|
||||
compare and test its LeRobot implementation, we fine-tuned
|
||||
[`allenai/MolmoAct2-LIBERO`](https://huggingface.co/allenai/MolmoAct2-LIBERO)
|
||||
for an additional 10k steps on the LIBERO dataset with per-GPU batch size 32 on
|
||||
8 H100 GPUs, then compared the results to the original MolmoAct2 reference
|
||||
results.
|
||||
|
||||
The LeRobot fine-tuned checkpoint reported here is available at
|
||||
[`allenai/MolmoAct2-LIBERO-LeRobot`](https://huggingface.co/allenai/MolmoAct2-LIBERO-LeRobot)
|
||||
and was trained on
|
||||
[`allenai/MolmoAct2-LIBERO-Dataset`](https://huggingface.co/datasets/allenai/MolmoAct2-LIBERO-Dataset).
|
||||
|
||||
| Benchmark | LeRobot Implementation | MolmoAct2 Original |
|
||||
| -------------- | ---------------------: | -----------------: |
|
||||
| LIBERO Spatial | 98.4% | 97.8% |
|
||||
| LIBERO Object | 100.0% | 100.0% |
|
||||
| LIBERO Goal | 98.0% | 97.8% |
|
||||
| LIBERO 10 | 96.6% | 93.2% |
|
||||
| Average | 98.25% | 97.20% |
|
||||
|
||||
These results demonstrate MolmoAct2's strong performance across diverse robotic
|
||||
manipulation tasks. To reproduce them, follow the instructions in the LIBERO
|
||||
evaluation section.
|
||||
|
||||
## Differences From the Original Implementation
|
||||
|
||||
This LeRobot port is intended to match MolmoAct2 behavior while using LeRobot's
|
||||
dataset, training, evaluation, checkpoint, and logging infrastructure. The main
|
||||
differences from the original training repository are:
|
||||
|
||||
- The original paper training stack loads the model in fp32 and trains under
|
||||
mixed precision. This LeRobot port usually loads the checkpoint directly in
|
||||
`policy.model_dtype=bfloat16` for lower memory use.
|
||||
- The original repository uses its own FSDP/model-parallel training path. The
|
||||
LeRobot port uses the standard LeRobot/Accelerate training path and has not
|
||||
been tested for multi-node training.
|
||||
- The original repository supports sequence packing. The LeRobot port trains on
|
||||
one LeRobot sample per item and pads to an inferred fixed sequence budget.
|
||||
- The LeRobot port follows LeRobot's optimizer, scheduler, checkpoint saving,
|
||||
dataset transforms, image augmentation, and Weights & Biases logging
|
||||
conventions.
|
||||
- The original training path supports mixed action horizons by padding to
|
||||
`max_action_horizon` and masking padded horizon slots in the action expert
|
||||
self-attention. This is useful when training across datasets with different
|
||||
control frequencies. The LeRobot port currently targets single-dataset
|
||||
fine-tuning, so `policy.chunk_size` overrides the checkpoint
|
||||
`max_action_horizon` and horizon masking is not implemented yet. Support for
|
||||
this mixed-horizon path is planned.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{fang2026molmoact2actionreasoningmodels,
|
||||
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
|
||||
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
|
||||
year={2026},
|
||||
eprint={2605.02881},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO},
|
||||
url={https://arxiv.org/abs/2605.02881},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model is licensed under Apache 2.0. It is intended for research and
|
||||
educational use in accordance with
|
||||
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
|
||||
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
39
docs/source/policy_molmoact2_README.md
Normal file
39
docs/source/policy_molmoact2_README.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# MolmoAct2
|
||||
|
||||
This repository contains the LeRobot policy implementation of
|
||||
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into LeRobot for
|
||||
training, evaluation, checkpointing, and dataset compatibility.
|
||||
|
||||
This implementation currently supports training and evaluation for the regular
|
||||
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
|
||||
not included in this LeRobot policy yet and is coming soon.
|
||||
|
||||
For the original MolmoAct2 training code used for the experiments reported in
|
||||
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
|
||||
## LIBERO Evaluation
|
||||
|
||||
Important: we found that `num_steps_wait=10` does not reliably let the LIBERO
|
||||
scene stabilize and can degrade measured success. All LIBERO evaluation results
|
||||
reported for this LeRobot implementation use `num_steps_wait=50`.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{fang2026molmoact2actionreasoningmodels,
|
||||
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
|
||||
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
|
||||
year={2026},
|
||||
eprint={2605.02881},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO},
|
||||
url={https://arxiv.org/abs/2605.02881},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model is licensed under Apache 2.0. It is intended for research and
|
||||
educational use in accordance with
|
||||
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
|
||||
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
232
docs/source/policy_vla_jepa_README.md
Normal file
232
docs/source/policy_vla_jepa_README.md
Normal file
@@ -0,0 +1,232 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
VLA-JEPA has three main components:
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
### Data flow
|
||||
|
||||
**Training:**
|
||||
|
||||
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
|
||||
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
|
||||
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
|
||||
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
|
||||
|
||||
**Inference:**
|
||||
Only Qwen + the action head are used. The world model is not needed at inference time.
|
||||
|
||||
### Action head details
|
||||
|
||||
Available presets via `action_model_type`:
|
||||
|
||||
| Preset | Hidden dim | Heads | Head dim |
|
||||
| ------- | ---------- | ----- | -------- |
|
||||
| `DiT-B` | 768 | 12 | 64 |
|
||||
| `DiT-L` | 1536 | 32 | 48 |
|
||||
|
||||
### World model details
|
||||
|
||||
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
|
||||
|
||||
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
|
||||
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
|
||||
|
||||
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
|
||||
|
||||
---
|
||||
|
||||
## Pretrained Checkpoints
|
||||
|
||||
Three checkpoints are available, converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
|
||||
|
||||
| Checkpoint | Dataset | Cameras | World model | Action dim |
|
||||
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
|
||||
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
|
||||
|
||||
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
Key parameters in `VLAJEPAConfig`:
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `chunk_size` | 7 | Number of actions predicted per inference call |
|
||||
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
|
||||
| `num_video_frames` | 8 | Video clip length fed to the world model |
|
||||
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
|
||||
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
||||
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
|
||||
|
||||
---
|
||||
|
||||
## Training
|
||||
|
||||
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
|
||||
|
||||
### Full training from scratch
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
policy.type=vla_jepa \
|
||||
policy.repo_id=your_org/your_repo \
|
||||
dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning from a pretrained checkpoint
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If you want to go further and freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning on a different embodiment
|
||||
|
||||
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
|
||||
|
||||
The layers that depend on `action_dim` and `state_dim` are:
|
||||
|
||||
| Layer | Key prefix |
|
||||
| ----------------------------------------- | ----------------------------------- |
|
||||
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
|
||||
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
|
||||
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
|
||||
|
||||
### Reproducing the LIBERO results
|
||||
|
||||
**Training on LIBERO:**
|
||||
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
|
||||
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--steps=30000
|
||||
```
|
||||
|
||||
**Evaluating the pretrained LIBERO-10 checkpoint:**
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
|
||||
```
|
||||
|
||||
To evaluate a subset of tasks only:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--env.task_ids='[0,1,2]' \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
```
|
||||
|
||||
**Expected results:**
|
||||
|
||||
| Suite | Episodes | Successes | Success Rate |
|
||||
| -------------- | -------- | --------- | ------------ |
|
||||
| libero_spatial | 100 | 93 | **95.0%** |
|
||||
| libero_object | 100 | 100 | **100.0%** |
|
||||
| libero_goal | 100 | 98 | **98.0%** |
|
||||
| libero_10 | 100 | 96 | **93.0%** |
|
||||
| **Overall** | **400** | **387** | **96.5%** |
|
||||
|
||||
---
|
||||
|
||||
## Fine-tuning on datasets with a different number of cameras
|
||||
|
||||
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
|
||||
|
||||
**Default behaviour — view padding / trimming (no action required)**
|
||||
|
||||
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
|
||||
|
||||
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
|
||||
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
|
||||
|
||||
**Option 1 — Disable the world model**
|
||||
|
||||
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.enable_world_model=false \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/single_camera_dataset
|
||||
```
|
||||
|
||||
**Option 2 — Reinitialize the predictor input projection**
|
||||
|
||||
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
@@ -161,7 +161,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -203,7 +203,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
186
docs/source/rebot_b601.mdx
Normal file
186
docs/source/rebot_b601.mdx
Normal file
@@ -0,0 +1,186 @@
|
||||
# reBot B601-DM
|
||||
|
||||
[reBot B601-DM](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/) is an open-source, low-cost robot arm from Seeed Studio for embodied-AI and imitation learning. It comes as a **follower** arm (the `B601-DM`, a 6-DOF arm plus gripper driven by Damiao CAN motors) and a **leader** arm (the `StarArm102` / `reBot Arm 102`, driven by FashionStar UART smart servos) used to teleoperate it.
|
||||
|
||||
This page covers **calibration** and **teleoperation** for both single-arm and bimanual (dual-arm) setups.
|
||||
|
||||
<div style="display: flex; align-items: center; gap: 10px;">
|
||||
<img
|
||||
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/b601dm_zeroposition.jpg"
|
||||
alt="reBot B601-DM follower arm at its zero position"
|
||||
width="48%"
|
||||
/>
|
||||
<img
|
||||
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/102_zeroposition.jpg"
|
||||
alt="reBot Arm 102 leader arm at its zero position"
|
||||
width="48%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
_Left: the B601-DM follower at its zero position. Right: the reBot Arm 102 leader at its zero position. Images courtesy of [Seeed Studio](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/)._
|
||||
|
||||
## Install LeRobot 🤗
|
||||
|
||||
Follow our [Installation Guide](./installation), then install the reBot support:
|
||||
|
||||
```bash
|
||||
pip install -e ".[rebot]"
|
||||
```
|
||||
|
||||
This pulls in `motorbridge` (CAN motor control for the B601-DM follower) and `motorbridge-smart-servo` (FashionStar UART servos for the reBot Arm 102 leader).
|
||||
|
||||
## Registered device types
|
||||
|
||||
| Type | Kind |
|
||||
| ------------------------ | -------------------------------------------- |
|
||||
| `rebot_b601_follower` | single-arm B601-DM follower robot |
|
||||
| `bi_rebot_b601_follower` | bimanual (dual-arm) follower robot |
|
||||
| `rebot_102_leader` | single-arm reBot Arm 102 leader teleoperator |
|
||||
| `bi_rebot_102_leader` | bimanual (dual-arm) leader teleoperator |
|
||||
|
||||
The bimanual types compose two single-arm instances and namespace each arm's
|
||||
observation/action keys with a `left_` / `right_` prefix. Per-arm settings are
|
||||
passed through nested `left_arm_config.*` / `right_arm_config.*` arguments.
|
||||
|
||||
## Find the USB ports
|
||||
|
||||
For each device, find the USB port associated with its motor bus using:
|
||||
|
||||
```bash
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
On Linux, remove `brltty` (`sudo apt remove brltty`) so it does not hold the
|
||||
leader's USB serial port. You may also need to grant access to the serial
|
||||
devices: `sudo chmod 666 /dev/ttyACM* /dev/ttyUSB*`.
|
||||
</Tip>
|
||||
|
||||
## Calibration
|
||||
|
||||
Neither arm stores a persistent hardware calibration: every time it connects, the motors are re-zeroed against the pose the arm is physically holding. Calibration simply records that zero pose. When prompted, **manually move the arm to its zero position** (the default sit-down pose shown above, gripper fully closed) and press <kbd>ENTER</kbd>.
|
||||
|
||||
### Follower (B601-DM)
|
||||
|
||||
<hfoptions id="calibrate-follower">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=follower \
|
||||
--robot.can_adapter=damiao
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
Connect the bimanual follower; calibration runs for the left arm, then the right arm.
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=bi_rebot_b601_follower \
|
||||
--robot.id=bi_follower \
|
||||
--robot.left_arm_config.port=/dev/ttyACM0 \
|
||||
--robot.left_arm_config.can_adapter=damiao \
|
||||
--robot.right_arm_config.port=/dev/ttyACM1 \
|
||||
--robot.right_arm_config.can_adapter=damiao
|
||||
```
|
||||
|
||||
Per-arm calibration files are saved with `_left` / `_right` suffixes on the id.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Leader (reBot Arm 102)
|
||||
|
||||
<hfoptions id="calibrate-leader">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.id=leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=bi_rebot_102_leader \
|
||||
--teleop.id=bi_leader \
|
||||
--teleop.left_arm_config.port=/dev/ttyUSB0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyUSB1
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Teleoperation
|
||||
|
||||
Once both arms are calibrated, drive the follower with the leader. The follower talks to its CAN bus through a Damiao serial bridge (`can_adapter=damiao`, the default) or a SocketCAN adapter (`can_adapter=socketcan`). See the [OpenArm page](./openarm) for more details on the SocketCAN adapter configuration.
|
||||
|
||||
<hfoptions id="teleoperate">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=follower \
|
||||
--robot.can_adapter=damiao \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.id=leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
The bimanual leader and follower reuse the single-arm classes; each arm is
|
||||
configured through nested `left_arm_config.*` / `right_arm_config.*` arguments,
|
||||
so a bimanual reBot Arm 102 leader drives a bimanual B601-DM follower.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=bi_rebot_b601_follower \
|
||||
--robot.id=bi_follower \
|
||||
--robot.left_arm_config.port=/dev/ttyACM0 \
|
||||
--robot.left_arm_config.can_adapter=damiao \
|
||||
--robot.right_arm_config.port=/dev/ttyACM1 \
|
||||
--robot.right_arm_config.can_adapter=damiao \
|
||||
--teleop.type=bi_rebot_102_leader \
|
||||
--teleop.id=bi_leader \
|
||||
--teleop.left_arm_config.port=/dev/ttyUSB0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyUSB1
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<Tip>
|
||||
The leader and follower share the same joint names (`shoulder_pan,
|
||||
shoulder_lift, elbow_flex, wrist_flex, wrist_yaw, wrist_roll, gripper`), so
|
||||
leader actions map directly onto the follower.
|
||||
</Tip>
|
||||
|
||||
If the motion of a joint is reversed, flip its sign in the leader's `joint_directions` (the gripper also carries a scale to widen its range to the follower):
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.can_adapter=damiao \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.joint_directions='{"shoulder_pan":-1,"shoulder_lift":-1,"elbow_flex":1,"wrist_flex":1,"wrist_yaw":1,"wrist_roll":-1,"gripper":-6}'
|
||||
```
|
||||
|
||||
## Recording datasets
|
||||
|
||||
Swap `lerobot-teleoperate` for `lerobot-record` (with the same `--robot.*` / `--teleop.*` arguments, plus `--dataset.*`) to record demonstrations for training. See [Imitation Learning for Robots](./il_robots) for the full workflow.
|
||||
|
||||
For hardware assembly and wiring, see the [Seeed Studio reBot wiki](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/).
|
||||
@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
|
||||
Once you are logged in, you can run inference in your setup by doing:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
||||
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
|
||||
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
|
||||
--dataset.episode_time_s=50 \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
# <- RTC optional, use when running on low power hardware \
|
||||
# --inference.type=rtc \
|
||||
# --inference.rtc.execution_horizon=10 \
|
||||
# --inference.rtc.max_guidance_weight=10.0 \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
# --teleop.id=my_red_leader_arm \
|
||||
# --display_data=true #optional use if you want to see the camera stream \
|
||||
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
|
||||
```
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@ This makes `save_episode()` near-instant (the video is already encoded by the ti
|
||||
| Parameter | CLI Flag | Type | Default | Description |
|
||||
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
|
||||
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
|
||||
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `vcodec` | `--dataset.camera_encoder.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `30` | Max buffered frames per camera (~1s at 30fps). Consumes RAM |
|
||||
|
||||
## 3. Performance Considerations
|
||||
|
||||
@@ -48,7 +48,7 @@ This parameter controls how many threads each encoder instance uses internally:
|
||||
|
||||
### Backpressure and Frame Dropping
|
||||
|
||||
Each camera has a bounded queue (`encoder_queue_maxsize`, default 60 frames). When the encoder can't keep up:
|
||||
Each camera has a bounded queue (`encoder_queue_maxsize`, default 30 frames). When the encoder can't keep up:
|
||||
|
||||
1. The queue fills up (consuming RAM)
|
||||
2. New frames are **dropped** (not blocked) — the capture loop continues uninterrupted
|
||||
@@ -82,15 +82,15 @@ Use HW encoding when:
|
||||
|
||||
### Available HW Encoders
|
||||
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | --------------------------------------------------- |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.camera_encoder.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.camera_encoder.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.camera_encoder.vcodec=auto` |
|
||||
|
||||
> [!NOTE]
|
||||
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
|
||||
@@ -100,15 +100,15 @@ Use HW encoding when:
|
||||
|
||||
## 5. Troubleshooting
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.camera_encoder.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.camera_encoder.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.camera_encoder.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
|
||||
## 6. Recommended Configurations
|
||||
|
||||
@@ -146,7 +146,7 @@ On very constrained systems, streaming encoding may compete too heavily with the
|
||||
# 2camsx 640x480x3 @30fps: Requires some tuning.
|
||||
|
||||
# Use H.264, disable streaming, consider batching encoding
|
||||
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
lerobot-record --dataset.camera_encoder.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
```
|
||||
|
||||
## 7. Closing note
|
||||
|
||||
210
docs/source/tools.mdx
Normal file
210
docs/source/tools.mdx
Normal file
@@ -0,0 +1,210 @@
|
||||
# Tools
|
||||
|
||||
LeRobot v3.1 supports **tool calls** in policies — assistant messages can
|
||||
emit structured invocations like `say(text="OK, starting now")` that the
|
||||
runtime dispatches to a real implementation (TTS, controller, logger, …).
|
||||
|
||||
This page covers:
|
||||
|
||||
1. Where the tool catalog lives.
|
||||
2. How the annotation pipeline produces tool-call atoms.
|
||||
3. How to add your own tool.
|
||||
|
||||
## Where tools are declared
|
||||
|
||||
Two layers.
|
||||
|
||||
**The catalog** — a list of OpenAI-style function schemas — lives at
|
||||
`meta/info.json["tools"]` on each dataset. Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"features": { "...": "..." },
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak."
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Read it via the dataset metadata accessor:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
|
||||
tools = meta.tools # list[dict] — OpenAI tool schemas
|
||||
```
|
||||
|
||||
If the dataset's `info.json` doesn't declare any tools, `meta.tools`
|
||||
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
|
||||
single-entry list with the canonical `say` schema. So unannotated
|
||||
datasets and chat-template consumers keep working without any
|
||||
configuration:
|
||||
|
||||
```python
|
||||
prompt_str = tokenizer.apply_chat_template(
|
||||
sample["messages"],
|
||||
tools=meta.tools, # works either way
|
||||
add_generation_prompt=False,
|
||||
tokenize=False,
|
||||
)
|
||||
```
|
||||
|
||||
**The implementations** — runnable Python — will live under
|
||||
`src/lerobot/tools/`, one file per tool. The runtime dispatcher and
|
||||
the canonical `say` implementation (wrapping Kyutai's pocket-tts) are
|
||||
not part of the catalog layer described here; today this layer ships
|
||||
only the schema storage and the `DEFAULT_TOOLS` fallback constant.
|
||||
|
||||
## Per-row tool _invocations_
|
||||
|
||||
The catalog above describes _what can be called_. The actual _call_ — the
|
||||
function name plus the argument values — is stored per-row, on the
|
||||
assistant atoms in `language_events`:
|
||||
|
||||
```python
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"style": null,
|
||||
"timestamp": 12.4,
|
||||
"camera": null,
|
||||
"tool_calls": [
|
||||
{ "type": "function",
|
||||
"function": { "name": "say", "arguments": { "text": "On it." } } }
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Recipes splice these into rendered messages via `tool_calls_from`:
|
||||
|
||||
```yaml
|
||||
user_interjection_response:
|
||||
bindings:
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${current_plan}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
tool_calls_from: speech,
|
||||
}
|
||||
```
|
||||
|
||||
The model's training target is one assistant turn that carries both the
|
||||
plan text _and_ the `say` tool call. At inference, the runtime parses
|
||||
the generated text back into structured `tool_calls` and dispatches to
|
||||
the matching implementation.
|
||||
|
||||
## How to add your own tool
|
||||
|
||||
> **Note:** Steps 2 and 3 below describe the runtime layer
|
||||
> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`,
|
||||
> `get_tools(meta)`) which is not part of the catalog layer shipped
|
||||
> today — those modules don't yet exist in the tree. Step 1 alone is
|
||||
> enough to make the tool visible to the chat template via
|
||||
> `meta.tools` so the model can learn to _generate_ the call;
|
||||
> executing the call at inference requires the runtime layer.
|
||||
|
||||
Three steps. Concrete example: a `record_observation` tool the policy
|
||||
can call to capture an extra observation outside the regular control
|
||||
loop.
|
||||
|
||||
### Step 1 — declare the schema
|
||||
|
||||
Add an entry under `meta/info.json["tools"]`. Either edit the file
|
||||
directly on disk _before_ running the annotation pipeline (it'll be
|
||||
preserved) or hand it to `lerobot-annotate` via a config flag.
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": [
|
||||
{ "type": "function", "function": { "name": "say", "...": "..." } },
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "record_observation",
|
||||
"description": "Capture a high-resolution still image for the user.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Short label for the saved image."
|
||||
}
|
||||
},
|
||||
"required": ["label"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The schema follows OpenAI's function-calling convention exactly, so the
|
||||
chat template can render it natively.
|
||||
|
||||
### Step 2 — implement the call
|
||||
|
||||
Create `src/lerobot/tools/record_observation.py`:
|
||||
|
||||
```python
|
||||
from .base import Tool
|
||||
from typing import Any
|
||||
|
||||
RECORD_OBSERVATION_SCHEMA: dict[str, Any] = { "...": "..." } # mirrors the JSON above
|
||||
|
||||
|
||||
class RecordObservationTool:
|
||||
name = "record_observation"
|
||||
schema = RECORD_OBSERVATION_SCHEMA
|
||||
|
||||
def __init__(self, schema: dict | None = None, output_dir: str = "."):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def call(self, arguments: dict) -> str:
|
||||
label = arguments["label"]
|
||||
# ... save the latest camera frame to <output_dir>/<label>.png ...
|
||||
return f"saved {label}.png"
|
||||
```
|
||||
|
||||
One file per tool keeps dependencies isolated — `record_observation`
|
||||
might pull `pillow`, while `say` pulls `pocket-tts`. Users installing
|
||||
only the tools they need avoid heavy transitive deps.
|
||||
|
||||
### Step 3 — register it
|
||||
|
||||
Add to `src/lerobot/tools/registry.py`:
|
||||
|
||||
```python
|
||||
from .record_observation import RecordObservationTool
|
||||
|
||||
TOOL_REGISTRY["record_observation"] = RecordObservationTool
|
||||
```
|
||||
|
||||
That's it. At runtime `get_tools(meta)` looks up each schema in
|
||||
`meta.tools`, instantiates the matching registered class, and returns
|
||||
a name → instance dict the dispatcher can route into.
|
||||
|
||||
If you want to use a tool _without_ writing an implementation (e.g. for
|
||||
training-time chat-template formatting only), step 1 alone is enough —
|
||||
the model still learns to _generate_ the call. Steps 2 and 3 are only
|
||||
needed to actually _execute_ it at inference.
|
||||
177
docs/source/topreward.mdx
Normal file
177
docs/source/topreward.mdx
Normal file
@@ -0,0 +1,177 @@
|
||||
# TOPReward
|
||||
|
||||
TOPReward is a **zero-shot reward model** that extracts token log-probabilities from an off-the-shelf vision-language model (VLM) as a robotic reward signal. Given a video trajectory and a task instruction, it returns the VLM's log-likelihood that the instruction is true — no fine-tuning required.
|
||||
|
||||
**Paper**: [TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics](https://arxiv.org/abs/2602.19313)
|
||||
**Project**: [topreward.github.io](https://topreward.github.io/webpage/)
|
||||
**Original code**: [github.com/TOPReward/TOPReward](https://github.com/TOPReward/TOPReward)
|
||||
**Default backbone**: [Qwen/Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
|
||||
|
||||
## Overview
|
||||
|
||||
TOPReward asks a generic VLM how likely a task instruction is, **conditioned on the video** of a robot trying to complete that task. Concretely, given:
|
||||
|
||||
- A trajectory video (a sequence of frames).
|
||||
- A task instruction (e.g. _"open the drawer"_).
|
||||
|
||||
it builds a chat prompt of the form
|
||||
|
||||
```text
|
||||
<video>
|
||||
"The above video shows a robot manipulation trajectory that completes the
|
||||
following task: <instruction> Decide whether the above statement is True
|
||||
or not. The answer is: True"
|
||||
```
|
||||
|
||||
forwards it through the VLM, label-masks everything except the very last token, and reads back the log-probability of that token — by default the literal `"True"` that closes the suffix template. The resulting `log P("True" | video + prompt + instruction)` is the reward.
|
||||
|
||||
Because the method only depends on a frozen VLM, TOPReward is **zero-shot**: there are no fine-tuned weights to host. The "model" in LeRobot is a small wrapper around `transformers`' `Qwen3VLForConditionalGeneration` plus the label-masking logic. The processor owns the tokeniser and builds the full chat prompt (EO-1/Robometer pattern).
|
||||
|
||||
## What the LeRobot integration covers
|
||||
|
||||
- Standard `reward_model.type=topreward` configuration through LeRobot.
|
||||
- VLM loading via the `transformers` `Qwen3VLForConditionalGeneration` API.
|
||||
- Prompt assembly + tokenisation in the processor (matching upstream `QwenClient.compute_instruction_reward`).
|
||||
- `compute_reward()` returns one scalar log-prob per sample.
|
||||
- LeRobot reward-model save/load — `save_pretrained` writes only `config.json` (the VLM is identified by `vlm_name`).
|
||||
- An offline labeling script that writes a `topreward_progress.parquet` (SARM-compatible schema) for RA-BC and overlay.
|
||||
|
||||
The current LeRobot port supports the **Qwen3-VL client only**. Other upstream clients (Gemini, OpenAI, Gemma, Molmo) can be added as follow-up extras.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot following the [Installation Guide](./installation).
|
||||
2. Install the TOPReward optional extra:
|
||||
|
||||
```bash
|
||||
pip install -e ".[topreward]"
|
||||
```
|
||||
|
||||
or, with `uv` from a source checkout:
|
||||
|
||||
```bash
|
||||
uv sync --extra topreward
|
||||
```
|
||||
|
||||
This pulls in `transformers`. The first time you run TOPReward, Hugging Face will also download the VLM weights from the Hub (~16 GB for Qwen3-VL-8B-Instruct). A GPU is strongly recommended.
|
||||
|
||||
## Model Inputs and Outputs
|
||||
|
||||
TOPReward expects:
|
||||
|
||||
- A trajectory video or sequence of frames.
|
||||
- A natural-language task description.
|
||||
|
||||
In LeRobot datasets the preprocessor reads:
|
||||
|
||||
| Config field | Default | Meaning |
|
||||
| ------------------------- | --------------------------- | --------------------------------------------- |
|
||||
| `reward_model.image_key` | `observation.images.top` | Camera observation used by TOPReward |
|
||||
| `reward_model.task_key` | `task` | Key in complementary data for the task string |
|
||||
| `reward_model.max_frames` | `16` | Cap on frames per sample |
|
||||
| `reward_model.fps` | `2.0` | Metadata passed to the Qwen video processor |
|
||||
| `reward_model.vlm_name` | `Qwen/Qwen3-VL-8B-Instruct` | Hugging Face Hub id of the underlying VLM |
|
||||
|
||||
The model returns:
|
||||
|
||||
- `compute_reward(batch)`: one log-probability per sample. Higher = better task-video alignment. When `success_threshold` is finite, returns the binary thresholded value instead.
|
||||
|
||||
## Usage
|
||||
|
||||
### Load the reward model directly
|
||||
|
||||
```python
|
||||
from lerobot.rewards.topreward import TOPRewardConfig, TOPRewardModel
|
||||
|
||||
cfg = TOPRewardConfig(
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
device="cuda",
|
||||
)
|
||||
reward_model = TOPRewardModel(cfg)
|
||||
```
|
||||
|
||||
### Use the reward factory
|
||||
|
||||
```python
|
||||
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
|
||||
|
||||
cfg = make_reward_model_config(
|
||||
"topreward",
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
device="cuda",
|
||||
image_key="observation.images.top",
|
||||
)
|
||||
reward_model = make_reward_model(cfg)
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
|
||||
```
|
||||
|
||||
The preprocessor tokenises the full prompt (video + prefix + instruction suffix), writes Qwen-VL tensors + `prompt_length` under `observation.topreward.*`. The model reads those tensors, label-masks based on `prompt_length`, and extracts the log-prob reward.
|
||||
|
||||
### Offline dataset labeling
|
||||
|
||||
Write a `topreward_progress.parquet` for RA-BC training and overlay videos:
|
||||
|
||||
```bash
|
||||
# Sparse-dense (15 anchors per episode, matches upstream)
|
||||
uv run python -m lerobot.rewards.topreward.compute_rabc_weights \
|
||||
--dataset-repo-id lerobot/libero_10_image \
|
||||
--num-samples 15 \
|
||||
--device cuda
|
||||
```
|
||||
|
||||
Then render the progress overlay for any episode:
|
||||
|
||||
```bash
|
||||
uv run examples/dataset/create_progress_videos.py \
|
||||
--repo-id lerobot/libero_10_image \
|
||||
--episode 0 \
|
||||
--progress-file topreward_progress.parquet \
|
||||
--gif
|
||||
```
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Prompt knobs
|
||||
|
||||
The default prompt mirrors the upstream paper:
|
||||
|
||||
```text
|
||||
prompt_prefix = "The above video shows a robot manipulation trajectory that completes the following task: "
|
||||
prompt_suffix_template = "{instruction} Decide whether the above statement is True or not. The answer is: True"
|
||||
```
|
||||
|
||||
Both are exposed on `TOPRewardConfig` for ablation. The suffix template **must** contain `{instruction}`.
|
||||
|
||||
### Chat template
|
||||
|
||||
`add_chat_template=True` wraps the full prompt (including instruction) with the tokenizer's chat template before tokenisation. Default is `False`, matching the upstream paper's main experiments.
|
||||
|
||||
## Limitations
|
||||
|
||||
- The current LeRobot port is **inference-only and zero-shot**; `forward()` is not overridden and `is_trainable` returns `False`.
|
||||
- Only the **Qwen3-VL family** is supported; other upstream clients are out of scope.
|
||||
- TOPReward inherits the underlying VLM's biases.
|
||||
|
||||
## References
|
||||
|
||||
- [TOPReward project page](https://topreward.github.io/webpage/)
|
||||
- [TOPReward paper](https://arxiv.org/abs/2602.19313)
|
||||
- [Original TOPReward code](https://github.com/TOPReward/TOPReward)
|
||||
- [Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{chen2026topreward,
|
||||
title={TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics},
|
||||
author={Chen, Shirui and Harrison, Cole and Lee, Ying-Chun and Yang, Angela Jin and
|
||||
Ren, Zhongzheng and Ratliff, Lillian J and Duan, Jiafei and Fox, Dieter and
|
||||
Krishna, Ranjay},
|
||||
journal={arXiv preprint arXiv:2602.19313},
|
||||
year={2026}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
The original TOPReward codebase is MIT-licensed. The LeRobot port follows the LeRobot Apache 2.0 license; the wrapped Qwen3-VL weights are subject to the original Qwen license.
|
||||
@@ -117,10 +117,10 @@ lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
--operation.g 2 \
|
||||
--operation.crf 30
|
||||
--operation.camera_encoder.vcodec libsvtav1 \
|
||||
--operation.camera_encoder.pix_fmt yuv420p \
|
||||
--operation.camera_encoder.g 2 \
|
||||
--operation.camera_encoder.crf 30
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
@@ -147,11 +147,7 @@ lerobot-edit-dataset \
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
|
||||
- `fast_decode`: Fast decode tuning option (default: 0)
|
||||
- `camera_encoder`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder.<field>. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
|
||||
117
docs/source/video_encoding_parameters.mdx
Normal file
117
docs/source/video_encoding_parameters.mdx
Normal file
@@ -0,0 +1,117 @@
|
||||
# Video encoding parameters
|
||||
|
||||
When video storage is enabled, LeRobot stores each camera stream as an **MP4** file instead of saving one image file per timestep. Video encoding compresses across time, which usually cuts dataset size and I/O compared to a pile of PNG, while keeping MP4 — a format every player and loader understands.
|
||||
|
||||
Encoding frames into an MP4 is a full FFmpeg pipeline: choice of encoder, pixel format, GOP/keyframes, quality vs. speed, and optional extra encoder flags. Most of these knobs are user-tunable through `camera_encoder`, a nested `VideoEncoderConfig` (`lerobot.configs.video.VideoEncoderConfig`) passed through PyAV.
|
||||
|
||||
You can set these parameters from the CLI with `--dataset.camera_encoder.<field>` (e.g. with `lerobot-record` or `lerobot-rollout`). The same block applies to every camera video stream in that run.
|
||||
|
||||
<Tip>
|
||||
Video storage must be on for `camera_encoder` to have any effect —
|
||||
`use_videos=True` in Python APIs, or `--dataset.video=true` on the CLI (the
|
||||
recording default). With video off, inputs stay as images and `camera_encoder`
|
||||
is ignored.
|
||||
</Tip>
|
||||
|
||||
For details on **when** frames are written vs. encoded (streaming vs. post-episode), queues, and other top-level `--dataset.*` switches, see [Streaming Video Encoding](./streaming_video_encoding). For an encoding-parameter comparison and experiments, see the [video-benchmark Space](https://huggingface.co/spaces/lerobot/video-benchmark).
|
||||
|
||||
---
|
||||
|
||||
## Example
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--robot.id=black \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue \
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.single_task="Grab the cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
--dataset.camera_encoder.vcodec=h264 \
|
||||
--dataset.camera_encoder.preset=fast \
|
||||
--dataset.camera_encoder.extra_options={"tune": "film", "profile:v": "high", "bf": 2} \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tuning parameters
|
||||
|
||||
<Tip warning={true}>
|
||||
The defaults are tuned to balance **compression ratio**, **visual quality**, and **decoding/seek speed** for typical robotics datasets. Changing them can affect both recording (CPU load, frame drops) and training (decoding throughput, image quality).
|
||||
|
||||
Only override these parameters if you have a specific reason to, and measure the impact on your pipeline before relying on the new settings.
|
||||
|
||||
</Tip>
|
||||
|
||||
All flags below are prefixed with `--dataset.camera_encoder.` on the CLI.
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| --------------- | ---------------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `vcodec` | `str` | `"libsvtav1"` | Video codec name. `"auto"` picks the first available hardware encoder from a fixed preference list, falling back to `libsvtav1`. |
|
||||
| `pix_fmt` | `str` | `"yuv420p"` | Output pixel format. Must be supported by the chosen codec in your FFmpeg build. |
|
||||
| `g` | `int` | `2` | GOP size — a keyframe every `g` frames. Emitted as FFmpeg option `g`. |
|
||||
| `crf` | `int` or `float` | `30` | Abstract quality value, mapped per codec (see the [mapping](#mapping-videoencoderconfig--ffmpeg-options) below). Lower → higher quality / larger output where the mapping is monotone. |
|
||||
| `preset` | `int` or `str` | `12` \* | Encoder speed preset; meaning depends on the codec. <br/>\* When unset and `vcodec=libsvtav1`, LeRobot defaults to `12`. |
|
||||
| `fast_decode` | `int` | `0` | `libsvtav1`: `0–2`, passed via `svtav1-params`. <br/>`h264` / `hevc` (software): if `>0`, sets `tune=fastdecode`. <br/>Other codecs: usually unused. |
|
||||
| `video_backend` | `str` | `"pyav"` | Only `"pyav"` is currently implemented for video encoding. |
|
||||
| `extra_options` | `dict` | `{}` | Extra FFmpeg or codec specific options merged after the structured fields above. Cannot override keys already set by those fields. |
|
||||
|
||||
---
|
||||
|
||||
## Persistence in dataset metadata
|
||||
|
||||
After the first episode of a video stream is encoded, the encoder configuration is **persisted into the dataset metadata** (`meta/info.json`) under each video feature, alongside the values probed from the file itself. For a video feature `observation.images.<camera>`, the layout in `info.json` is:
|
||||
|
||||
```json
|
||||
{
|
||||
"features": {
|
||||
"observation.images.laptop": {
|
||||
"dtype": "video",
|
||||
"shape": [480, 640, 3],
|
||||
"info": {
|
||||
"video.height": 480,
|
||||
"video.width": 640,
|
||||
"video.codec": "h264",
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.fps": 30,
|
||||
"video.channels": 3,
|
||||
"video.is_depth_map": false,
|
||||
"video.g": 2,
|
||||
"video.crf": 30,
|
||||
"video.preset": "fast",
|
||||
"video.fast_decode": 0,
|
||||
"video.video_backend": "pyav",
|
||||
"video.extra_options": { "tune": "film", "profile:v": "high", "bf": 2 }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Two sources contribute to the `info` block:
|
||||
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
|
||||
|
||||
<Tip>
|
||||
This block is populated **once**, from the **first** episode. It assumes every
|
||||
episode in the dataset was encoded with the same `camera_encoder`. Changing
|
||||
encoder settings partway through a recording is not supported — the
|
||||
`info.json` will only reflect the parameters used for the first episode.
|
||||
</Tip>
|
||||
|
||||
---
|
||||
|
||||
## Merging datasets
|
||||
|
||||
When aggregating datasets with `merge_datasets`, video files are concatenated as-is (no re-encoding), and encoder fields in `info.json` are merged per-key:
|
||||
|
||||
- **Stream-derived fields must match** across sources: `video.codec`, `video.pix_fmt`, `video.height`, `video.width`, `video.fps`. Otherwise FFmpeg's concat demuxer fails.
|
||||
- **Encoder-tuning fields are merged loosely**: `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.extra_options`. If every source agrees, the value is kept; if not, it's set to `null` (or `{}` for `video.extra_options`) and a warning is logged.
|
||||
@@ -15,10 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
|
||||
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
|
||||
|
||||
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
||||
of the source video, draws a progress line on each frame, and writes the result.
|
||||
The progress data is read from a parquet file that lives alongside the dataset
|
||||
(configurable via ``--progress-file``).
|
||||
|
||||
Usage:
|
||||
python examples/dataset/create_progress_videos.py \
|
||||
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
|
||||
TASK_FONT_SCALE = 0.55
|
||||
|
||||
|
||||
def download_episode_metadata(repo_id: str, episode: int) -> Path:
|
||||
"""Download only the metadata and sarm_progress files for a dataset.
|
||||
def download_episode_metadata(
|
||||
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> Path:
|
||||
"""Download only the metadata and per-frame progress file for a dataset.
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace dataset repository ID.
|
||||
episode: Episode index (used for logging only; all meta is fetched).
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Local cache path for the downloaded snapshot.
|
||||
"""
|
||||
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
|
||||
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
|
||||
local_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||
allow_patterns=["meta/**", progress_file],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
|
||||
return video_path
|
||||
|
||||
|
||||
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
|
||||
"""Load sarm_progress values for an episode.
|
||||
def load_progress_data(
|
||||
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> np.ndarray | None:
|
||||
"""Load per-frame progress values for an episode.
|
||||
|
||||
Args:
|
||||
local_path: Dataset cache root.
|
||||
episode: Episode index.
|
||||
progress_file: Filename of the per-frame progress parquet.
|
||||
|
||||
Returns:
|
||||
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
||||
"""
|
||||
parquet_path = local_path / "sarm_progress.parquet"
|
||||
parquet_path = local_path / progress_file
|
||||
if not parquet_path.exists():
|
||||
logging.warning("sarm_progress.parquet not found")
|
||||
logging.warning("%s not found", progress_file)
|
||||
return None
|
||||
df = pd.read_parquet(parquet_path)
|
||||
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
|
||||
logging.info(" %s columns: %s", progress_file, list(df.columns))
|
||||
episode_df = df[df["episode_index"] == episode].copy()
|
||||
if episode_df.empty:
|
||||
logging.warning("No sarm_progress rows for episode %d", episode)
|
||||
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
|
||||
return None
|
||||
episode_df = episode_df.sort_values("frame_index")
|
||||
|
||||
@@ -576,6 +585,7 @@ def process_dataset(
|
||||
camera_key: str | None,
|
||||
output_dir: Path,
|
||||
create_gif: bool = False,
|
||||
progress_file: str = "sarm_progress.parquet",
|
||||
) -> Path | None:
|
||||
"""Full pipeline: download, extract metadata, composite progress, write output.
|
||||
|
||||
@@ -585,6 +595,8 @@ def process_dataset(
|
||||
camera_key: Camera key to use, or None for auto-selection.
|
||||
output_dir: Directory to write output files.
|
||||
create_gif: If True, also generate a GIF from the MP4.
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Path to the final output file, or None on failure.
|
||||
@@ -592,7 +604,7 @@ def process_dataset(
|
||||
safe_name = repo_id.replace("/", "_")
|
||||
logging.info("Processing: %s | episode %d", repo_id, episode)
|
||||
|
||||
local_path = download_episode_metadata(repo_id, episode)
|
||||
local_path = download_episode_metadata(repo_id, episode, progress_file)
|
||||
logging.info(" Local cache: %s", local_path)
|
||||
|
||||
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
||||
@@ -600,9 +612,9 @@ def process_dataset(
|
||||
|
||||
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
||||
|
||||
progress_data = load_progress_data(local_path, episode)
|
||||
progress_data = load_progress_data(local_path, episode, progress_file)
|
||||
if progress_data is None:
|
||||
logging.error("Could not load sarm_progress data. Skipping overlay.")
|
||||
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
|
||||
return None
|
||||
|
||||
logging.info(" Progress frames: %d", len(progress_data))
|
||||
@@ -627,7 +639,7 @@ def process_dataset(
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
|
||||
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
@@ -658,6 +670,15 @@ def main() -> None:
|
||||
action="store_true",
|
||||
help="Also generate a GIF from the MP4 output.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-file",
|
||||
type=str,
|
||||
default="sarm_progress.parquet",
|
||||
help=(
|
||||
"Filename of the per-frame progress parquet inside the dataset repo "
|
||||
"(default: 'sarm_progress.parquet')."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
@@ -670,6 +691,7 @@ def main() -> None:
|
||||
camera_key=args.camera_key,
|
||||
output_dir=args.output_dir,
|
||||
create_gif=args.gif,
|
||||
progress_file=args.progress_file,
|
||||
)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -80,7 +80,7 @@
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Dataset\n",
|
||||
"HF_USER = \"your_hf_username\" # `huggingface-cli whoami` to find your username\n",
|
||||
"HF_USER = \"your_hf_username\" # `hf auth whoami` to find your username\n",
|
||||
"DATASET_NAME = \"my_so101_dataset\"\n",
|
||||
"TASK_DESCRIPTION = \"pick and place the block\"\n",
|
||||
"NUM_EPISODES = 10\n",
|
||||
@@ -291,7 +291,34 @@
|
||||
"\n",
|
||||
"Uses `POLICY_PATH` from the Configuration cell (defaults to the Hub repo ID). You can also put there the `LAST_CHECKPOINT_PATH`.\n",
|
||||
"\n",
|
||||
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details."
|
||||
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details.\n",
|
||||
"\n",
|
||||
"Recently ```lerobot-rollout``` was introduced, you can [read more about it here](https://huggingface.co/docs/lerobot/main/en/il_robots?eval=Base+mode+%28no+recording%29#run-inference-and-evaluate-your-policy)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print_cmd(\n",
|
||||
" \"lerobot-rollout\",\n",
|
||||
" \"--strategy.type=base\",\n",
|
||||
" f\"--policy.path={POLICY_PATH}\",\n",
|
||||
" f\"--robot.type={ROBOT_TYPE}\",\n",
|
||||
" f\"--robot.port={ROBOT_PORT}\",\n",
|
||||
" CAMERAS_FLAG,\n",
|
||||
" f'--task=\"{TASK_DESCRIPTION}\"',\n",
|
||||
" \"--duration=60\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"if you are using the V0.5.1 release you should use ```lerobot-record``` instead of rollout"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -95,7 +95,7 @@ dependencies = [
|
||||
|
||||
# ── Feature-scoped extras ──────────────────────────────────
|
||||
dataset = [
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"datasets>=4.7.0,<5.0.0",
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
@@ -138,7 +138,9 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
||||
# Common
|
||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
@@ -151,6 +153,8 @@ pyserial-dep = ["pyserial>=3.5,<4.0"]
|
||||
deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"]
|
||||
pynput-dep = ["pynput>=1.7.8,<1.9.0"]
|
||||
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
|
||||
motorbridge-dep = ["motorbridge>=0.3.2,<0.4.0"]
|
||||
motorbridge-smart-servo-dep = ["motorbridge-smart-servo>=0.0.4,<0.1.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
|
||||
@@ -174,6 +178,9 @@ unitree_g1 = [
|
||||
"lerobot[pygame-dep]",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
|
||||
# leader (motorbridge-smart-servo / FashionStar UART servos).
|
||||
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
@@ -191,6 +198,7 @@ wallx = [
|
||||
"lerobot[qwen-vl-utils-dep]",
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||
groot = [
|
||||
@@ -204,9 +212,11 @@ groot = [
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
@@ -260,16 +270,19 @@ all = [
|
||||
"lerobot[lekiwi]",
|
||||
"lerobot[openarms]",
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[rebot]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[diffusion]",
|
||||
"lerobot[multi_task_dit]",
|
||||
"lerobot[wallx]",
|
||||
"lerobot[pi]",
|
||||
"lerobot[molmoact2]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[vla_jepa]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
@@ -280,6 +293,7 @@ all = [
|
||||
"lerobot[libero]; sys_platform == 'linux'",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]",
|
||||
"lerobot[topreward]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
@@ -395,8 +409,11 @@ default.extend-ignore-identifiers-re = [
|
||||
"ein",
|
||||
"thw",
|
||||
"inpt",
|
||||
"arange",
|
||||
"is_compileable",
|
||||
"ROBOTIS",
|
||||
"OT_VALUE"
|
||||
"OT_VALUE",
|
||||
"VanderBilt"
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
|
||||
@@ -199,12 +199,13 @@ class OpenCVCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
"""
|
||||
|
||||
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
|
||||
if self.config.fourcc is not None:
|
||||
self._validate_fourcc()
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
set_fourcc_after_size_and_fps = platform.system() == "Windows"
|
||||
if self.config.fourcc is not None and not set_fourcc_after_size_and_fps:
|
||||
self._validate_fourcc()
|
||||
|
||||
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
|
||||
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
||||
|
||||
@@ -222,6 +223,11 @@ class OpenCVCamera(Camera):
|
||||
else:
|
||||
self._validate_fps()
|
||||
|
||||
if self.config.fourcc is not None and set_fourcc_after_size_and_fps:
|
||||
# On Windows with DSHOW, changing the resolution can silently override the FOURCC setting.
|
||||
# Set FOURCC last to make sure the requested pixel format is actually enforced.
|
||||
self._validate_fourcc()
|
||||
|
||||
def _validate_fps(self) -> None:
|
||||
"""Validates and sets the camera's frames per second (FPS)."""
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||
from .dataset import DatasetRecordConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
from .types import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
@@ -31,6 +32,12 @@ from .types import (
|
||||
PolicyFeature,
|
||||
RTCAttentionSchedule,
|
||||
)
|
||||
from .video import (
|
||||
VALID_VIDEO_CODECS,
|
||||
VIDEO_ENCODER_INFO_KEYS,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Types
|
||||
@@ -43,7 +50,16 @@ __all__ = [
|
||||
"DatasetRecordConfig",
|
||||
"DatasetConfig",
|
||||
"EvalConfig",
|
||||
"MessageTurn",
|
||||
"PeftConfig",
|
||||
"PreTrainedConfig",
|
||||
"TrainingRecipe",
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
"VideoEncoderConfig",
|
||||
# Defaults
|
||||
"camera_encoder_defaults",
|
||||
# Constants
|
||||
"VALID_VIDEO_CODECS",
|
||||
"VIDEO_ENCODER_INFO_KEYS",
|
||||
]
|
||||
|
||||
@@ -14,10 +14,12 @@
|
||||
|
||||
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from .video import VideoEncoderConfig, camera_encoder_defaults
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetRecordConfig:
|
||||
@@ -55,10 +57,9 @@ class DatasetRecordConfig:
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
|
||||
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
|
||||
# Use 'auto' to auto-detect the best available hardware encoder.
|
||||
vcodec: str = "libsvtav1"
|
||||
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
|
||||
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
|
||||
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.transforms import ImageTransformsConfig
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,7 +34,7 @@ class DatasetConfig:
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
video_backend: str = field(default_factory=get_safe_default_video_backend)
|
||||
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
|
||||
@@ -18,8 +18,8 @@ from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot import envs, policies # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
|
||||
from . import parser
|
||||
from .default import EvalConfig
|
||||
from .policies import PreTrainedConfig
|
||||
|
||||
|
||||
@@ -255,8 +255,7 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) ->
|
||||
remaining = config_data[field]
|
||||
if remaining:
|
||||
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
||||
else:
|
||||
del config_data[field]
|
||||
del config_data[field]
|
||||
modified = True
|
||||
|
||||
if not modified:
|
||||
@@ -311,7 +310,13 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
else:
|
||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
||||
if config_path_cli:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = draccus.parse(
|
||||
config_class=argtype,
|
||||
config_path=config_path_cli or config_path,
|
||||
args=cli_args,
|
||||
)
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
|
||||
206
src/lerobot/configs/recipe.py
Normal file
206
src/lerobot/configs/recipe.py
Normal file
@@ -0,0 +1,206 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
MessageRole = Literal["user", "assistant", "system", "tool"]
|
||||
MessageStream = Literal["high_level", "low_level"]
|
||||
|
||||
DEFAULT_BINDINGS = {
|
||||
"subtask": "active_at(t, style=subtask)",
|
||||
"memory": "active_at(t, style=memory)",
|
||||
"plan": "active_at(t, style=plan)",
|
||||
"speech": "emitted_at(t, role=assistant, tool_name=say)",
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant)",
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
}
|
||||
|
||||
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
"""``${name}`` placeholder pattern used by both recipe binding-reference
|
||||
discovery (here) and rendered-message substitution (in ``language_render``)."""
|
||||
|
||||
_VALID_ROLES = frozenset(get_args(MessageRole))
|
||||
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageTurn:
|
||||
"""A single chat-style turn in a recipe template.
|
||||
|
||||
``content`` may be a plain string, a list of HF-style multimodal blocks, or
|
||||
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
|
||||
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
|
||||
training target, and ``if_present`` skips the turn when the named binding
|
||||
resolves to ``None``.
|
||||
"""
|
||||
|
||||
role: MessageRole
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
stream: MessageStream | None = None
|
||||
target: bool = False
|
||||
if_present: str | None = None
|
||||
tool_calls_from: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate role, stream, and content after dataclass construction."""
|
||||
if self.role not in _VALID_ROLES:
|
||||
raise ValueError(f"Unsupported message role: {self.role!r}")
|
||||
# ``stream`` is typed Optional only so the dataclass can keep its
|
||||
# field ordering, but recipes must always tag every turn with a
|
||||
# stream — the renderer's ``_validate_rendered`` would reject
|
||||
# ``None`` later on. Fail at construction so the bad recipe is
|
||||
# caught at YAML load time rather than at the first sample.
|
||||
if self.stream is None:
|
||||
raise ValueError(
|
||||
f"MessageTurn(role={self.role!r}) is missing a stream — "
|
||||
f"every turn must declare one of {sorted(_VALID_STREAMS)}."
|
||||
)
|
||||
if self.stream not in _VALID_STREAMS:
|
||||
raise ValueError(f"Unsupported message stream: {self.stream!r}")
|
||||
if self.content is None and self.tool_calls_from is None:
|
||||
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
|
||||
if self.content is not None and not isinstance(self.content, (str, list)):
|
||||
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
|
||||
if isinstance(self.content, list):
|
||||
for block in self.content:
|
||||
if not isinstance(block, dict) or "type" not in block:
|
||||
raise ValueError(
|
||||
"Multimodal content blocks must be HF-style dictionaries with a type key."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
|
||||
"""Construct a :class:`MessageTurn` from a plain dictionary."""
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingRecipe:
|
||||
"""A recipe describing how to render training samples from language rows.
|
||||
|
||||
A recipe is either a *message recipe* (``messages`` plus optional
|
||||
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
|
||||
sub-recipes). ``weight`` is only meaningful inside a blend.
|
||||
"""
|
||||
|
||||
messages: list[MessageTurn] | None = None
|
||||
bindings: dict[str, str] | None = None
|
||||
blend: dict[str, TrainingRecipe] | None = None
|
||||
weight: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
|
||||
if self.messages is not None and self.blend is not None:
|
||||
raise ValueError("TrainingRecipe must set only one of messages or blend.")
|
||||
if self.messages is None and self.blend is None:
|
||||
raise ValueError("TrainingRecipe must set one of messages or blend.")
|
||||
|
||||
if self.messages is not None:
|
||||
self._validate_message_recipe()
|
||||
if self.blend is not None:
|
||||
self._validate_blend_recipe()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
|
||||
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
|
||||
data = dict(data)
|
||||
if data.get("messages") is not None:
|
||||
data["messages"] = [
|
||||
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
|
||||
for turn in data["messages"]
|
||||
]
|
||||
if data.get("blend") is not None:
|
||||
data["blend"] = {
|
||||
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
|
||||
for name, recipe in data["blend"].items()
|
||||
}
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
"""Ensure every templated binding is known and at least one turn is a target."""
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
for turn in self.messages:
|
||||
missing = self._referenced_bindings(turn) - known_bindings
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
if not any(turn.target for turn in self.messages):
|
||||
raise ValueError("Message recipes must contain at least one target turn.")
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
"""Ensure each blend component is a non-empty, weighted message recipe."""
|
||||
assert self.blend is not None
|
||||
if not self.blend:
|
||||
raise ValueError("Blend recipes must contain at least one component.")
|
||||
|
||||
for name, recipe in self.blend.items():
|
||||
if recipe.blend is not None:
|
||||
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
|
||||
if recipe.messages is None:
|
||||
raise ValueError(f"Blend component {name!r} must define messages.")
|
||||
if recipe.weight is None:
|
||||
raise ValueError(f"Blend component {name!r} must define weight.")
|
||||
if recipe.weight <= 0:
|
||||
raise ValueError(f"Blend component {name!r} must have a positive weight.")
|
||||
|
||||
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
|
||||
"""Return the binding names that ``turn`` references via placeholders or attributes."""
|
||||
names: set[str] = set()
|
||||
if turn.if_present is not None:
|
||||
names.add(turn.if_present)
|
||||
if turn.tool_calls_from is not None:
|
||||
names.add(turn.tool_calls_from)
|
||||
names.update(_placeholders_in_content(turn.content))
|
||||
return names
|
||||
|
||||
|
||||
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
|
||||
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
|
||||
if content is None:
|
||||
return set()
|
||||
if isinstance(content, str):
|
||||
return set(PLACEHOLDER_RE.findall(content))
|
||||
|
||||
names: set[str] = set()
|
||||
for block in content:
|
||||
for value in block.values():
|
||||
if isinstance(value, str):
|
||||
names.update(PLACEHOLDER_RE.findall(value))
|
||||
return names
|
||||
|
||||
|
||||
def load_recipe(path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
return TrainingRecipe.from_yaml(path)
|
||||
@@ -27,12 +27,13 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
from .types import PolicyFeature
|
||||
|
||||
T = TypeVar("T", bound="RewardModelConfig")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -89,9 +90,9 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
raise NotImplementedError
|
||||
def get_optimizer_preset(self) -> OptimizerConfig | None:
|
||||
"""Default optimizer for this reward model, or ``None`` for zero-shot models."""
|
||||
return None
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return None
|
||||
|
||||
@@ -25,11 +25,11 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot import envs
|
||||
from lerobot.configs import parser
|
||||
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from . import parser
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
|
||||
235
src/lerobot/configs/video.py
Normal file
235
src/lerobot/configs/video.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Note: We subclass str so that serialization is straightforward
|
||||
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
|
||||
|
||||
"""Video encoder configurations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and the chosen video backend.
|
||||
# Determines the order of preference for auto-selection when vcodec="auto" is used.
|
||||
HW_VIDEO_CODECS = [
|
||||
"h264_videotoolbox", # macOS
|
||||
"hevc_videotoolbox", # macOS
|
||||
"h264_nvenc", # NVIDIA GPU
|
||||
"hevc_nvenc", # NVIDIA GPU
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
|
||||
# Aliases for legacy video codec names.
|
||||
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
|
||||
|
||||
|
||||
LIBSVTAV1_DEFAULT_PRESET: int = 12
|
||||
|
||||
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
|
||||
# ``vcodec``` and ``pix_fmt`` are derived from the video stream directly.
|
||||
VIDEO_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset(
|
||||
{"g", "crf", "preset", "fast_decode", "extra_options", "video_backend"}
|
||||
)
|
||||
VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
|
||||
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderConfig:
|
||||
"""Video encoder configuration.
|
||||
|
||||
Attributes:
|
||||
vcodec: Video encoder name. ``"auto"`` is resolved during
|
||||
construction (HW encoder if available, else ``libsvtav1``).
|
||||
pix_fmt: Pixel format (e.g. ``"yuv420p"``).
|
||||
g: GOP size (keyframe interval).
|
||||
crf: Quality level — mapped to the native quality parameter of the
|
||||
codec (``crf`` for software, ``qp`` for NVENC/VAAPI,
|
||||
``q:v`` for VideoToolbox, ``global_quality`` for QSV).
|
||||
preset: Speed/quality preset. Accepted type is per-codec.
|
||||
fast_decode: Fast-decode tuning. For ``libsvtav1`` this is a level (0-2)
|
||||
embedded in ``svtav1-params``. For ``h264`` and ``hevc`` non-zero values
|
||||
set ``tune=fastdecode``. Ignored for other codecs.
|
||||
video_backend: Python to be used for encoding. Only ``"pyav"``
|
||||
is currently supported.
|
||||
extra_options: Free-form dictionary of additional video encoder options
|
||||
(e.g. ``{"tune": "film", "profile:v": "high", "bf": 2}``).
|
||||
"""
|
||||
|
||||
vcodec: str = "libsvtav1" # TODO(CarolinePascal): rename to codec ?
|
||||
pix_fmt: str = "yuv420p"
|
||||
g: int | None = 2
|
||||
crf: int | float | None = 30
|
||||
preset: int | str | None = None
|
||||
fast_decode: int = 0
|
||||
# TODO(CarolinePascal): add torchcodec support + find a way to unify the
|
||||
# two backends (encoding and decoding).
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
|
||||
if self.preset is None and self.vcodec == "libsvtav1":
|
||||
self.preset = LIBSVTAV1_DEFAULT_PRESET
|
||||
self.validate()
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> VideoEncoderConfig:
|
||||
"""Reconstruct a :class:`VideoEncoderConfig` from a video feature's ``info`` block.
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
"""
|
||||
video_info = video_info or {}
|
||||
kwargs: dict[str, Any] = {}
|
||||
|
||||
for src_key, dst_field in (("video.codec", "vcodec"), ("video.pix_fmt", "pix_fmt")):
|
||||
value = video_info.get(src_key)
|
||||
if value is not None:
|
||||
kwargs[dst_field] = value
|
||||
|
||||
for field_name in VIDEO_ENCODER_INFO_FIELD_NAMES:
|
||||
value = video_info.get(f"video.{field_name}")
|
||||
if value is None:
|
||||
continue
|
||||
# Persisted as ``{}`` after merges with disagreeing sources — treat as default.
|
||||
if field_name == "extra_options" and not value:
|
||||
continue
|
||||
kwargs[field_name] = value
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of available encoders based on the specified video backend.
|
||||
|
||||
Args:
|
||||
encoders: List of encoder names to detect. If a string, it is converted to a list.
|
||||
Returns:
|
||||
List of available encoder names. If the video backend is not "pyav", returns an empty list.
|
||||
"""
|
||||
if self.video_backend == "pyav":
|
||||
require_package("av", extra="dataset")
|
||||
from lerobot.datasets import detect_available_encoders_pyav
|
||||
|
||||
return detect_available_encoders_pyav(encoders)
|
||||
return []
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate the video encoder configuration."""
|
||||
if self.video_backend == "pyav":
|
||||
require_package("av", extra="dataset")
|
||||
from lerobot.datasets import check_video_encoder_parameters_pyav
|
||||
|
||||
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options())
|
||||
|
||||
def resolve_vcodec(self) -> None:
|
||||
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
|
||||
|
||||
For ``"auto"``, the first hardware encoder in the preference list that is available is chosen; if none are available, ``libsvtav1`` is used. If the
|
||||
resolved codec (explicit or after auto-selection) is not available, raises ``ValueError``.
|
||||
|
||||
Stream-derived canonical codec names listed in :data:`VIDEO_CODECS_ALIASES` are
|
||||
rewritten to their corresponding encoder name (e.g. ``"av1"`` → ``"libsvtav1"``).
|
||||
"""
|
||||
self.vcodec = VIDEO_CODECS_ALIASES.get(self.vcodec, self.vcodec)
|
||||
if self.vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{self.vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if self.vcodec == "auto":
|
||||
available = self.detect_available_encoders(HW_VIDEO_CODECS)
|
||||
for encoder in HW_VIDEO_CODECS:
|
||||
if encoder in available:
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
self.vcodec = encoder
|
||||
return
|
||||
logger.warning("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
self.vcodec = "libsvtav1"
|
||||
|
||||
if self.detect_available_encoders(self.vcodec):
|
||||
logger.info(f"Using video codec: {self.vcodec}")
|
||||
return
|
||||
raise ValueError(f"Unsupported video codec: {self.vcodec} with video backend {self.video_backend}")
|
||||
|
||||
def get_codec_options(
|
||||
self, encoder_threads: int | None = None, as_strings: bool = False
|
||||
) -> dict[str, Any]:
|
||||
"""Translate the tuning fields to codec-specific options.
|
||||
|
||||
``VideoEncoderConfig.extra_options`` are merged last but never override a structured field.
|
||||
|
||||
Args:
|
||||
encoder_threads: Number of encoder threads set globally for all VideoEncoderConfigs.
|
||||
For libsvtav1, this is mapped to ``lp`` via ``svtav1-params``.
|
||||
For h264/hevc, this is mapped to ``threads``.
|
||||
Hardware encoders ignore this parameter.
|
||||
as_strings: If ``True``, casts values to strings.
|
||||
"""
|
||||
opts: dict[str, Any] = {}
|
||||
|
||||
def set_if(key: str, value: Any) -> None:
|
||||
if value is not None:
|
||||
opts[key] = value if not as_strings else str(value)
|
||||
|
||||
# GOP size is not a codec-specific option, so it is always set.
|
||||
set_if("g", self.g)
|
||||
|
||||
if self.vcodec == "libsvtav1":
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
svtav1_parts: list[str] = []
|
||||
if self.fast_decode is not None:
|
||||
svtav1_parts.append(f"fast-decode={max(0, min(2, self.fast_decode))}")
|
||||
if encoder_threads is not None:
|
||||
svtav1_parts.append(f"lp={encoder_threads}")
|
||||
if svtav1_parts:
|
||||
opts["svtav1-params"] = ":".join(svtav1_parts)
|
||||
elif self.vcodec in ("h264", "hevc"):
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
if self.fast_decode:
|
||||
opts["tune"] = "fastdecode"
|
||||
set_if("threads", encoder_threads)
|
||||
elif self.vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
if self.crf is not None:
|
||||
opts["q:v"] = max(1, min(100, 100 - self.crf * 2))
|
||||
elif self.vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
opts["rc"] = 0
|
||||
set_if("qp", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "h264_vaapi":
|
||||
set_if("qp", self.crf)
|
||||
elif self.vcodec == "h264_qsv":
|
||||
set_if("global_quality", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
else:
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
|
||||
# Extra options are merged last but never override structured fields (values are kept as given).
|
||||
for k, v in self.extra_options.items():
|
||||
if k not in opts:
|
||||
set_if(k, v)
|
||||
|
||||
return opts
|
||||
|
||||
|
||||
def camera_encoder_defaults() -> VideoEncoderConfig:
|
||||
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
|
||||
return VideoEncoderConfig()
|
||||
@@ -31,15 +31,25 @@ from .dataset_tools import (
|
||||
modify_features,
|
||||
modify_tasks,
|
||||
recompute_stats,
|
||||
reencode_dataset,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
STYLE_REGISTRY,
|
||||
column_for_style,
|
||||
)
|
||||
from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
@@ -53,12 +63,19 @@ __all__ = [
|
||||
"CODEBASE_VERSION",
|
||||
"DEFAULT_EPISODES_PATH",
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"EpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
"LeRobotDataset",
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"PERSISTENT_STYLES",
|
||||
"STYLE_REGISTRY",
|
||||
"StreamingLeRobotDataset",
|
||||
"VideoEncodingManager",
|
||||
"check_video_encoder_parameters_pyav",
|
||||
"detect_available_encoders_pyav",
|
||||
"add_features",
|
||||
"aggregate_datasets",
|
||||
"aggregate_pipeline_dataset_features",
|
||||
@@ -66,6 +83,7 @@ __all__ = [
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
"get_feature_stats",
|
||||
"load_episodes",
|
||||
@@ -74,6 +92,7 @@ __all__ = [
|
||||
"modify_features",
|
||||
"modify_tasks",
|
||||
"recompute_stats",
|
||||
"reencode_dataset",
|
||||
"remove_feature",
|
||||
"resolve_delta_timestamps",
|
||||
"safe_stop_image_writer",
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
@@ -23,9 +24,11 @@ import datasets
|
||||
import pandas as pd
|
||||
import tqdm
|
||||
|
||||
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||
|
||||
from .compute_stats import aggregate_stats
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .feature_utils import get_hf_features_from_features
|
||||
from .feature_utils import features_equal_for_merge, get_hf_features_from_features
|
||||
from .io_utils import (
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
@@ -46,11 +49,54 @@ from .utils import (
|
||||
from .video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
|
||||
|
||||
def merge_video_feature_info_for_aggregate(all_metadata: list[LeRobotDatasetMetadata]) -> dict[str, dict]:
|
||||
"""Create a merged video feature info dictionary for aggregation. The video encoder info is merged field-by-field: each key is kept only when every source agrees; otherwise that key is set to ``null`` (or ``{}`` for ``video.extra_options``) and a warning is logged.
|
||||
|
||||
Args:
|
||||
all_metadata: List of LeRobotDatasetMetadata objects to merge.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary of merged video feature info.
|
||||
"""
|
||||
merged_info = copy.deepcopy(all_metadata[0].features)
|
||||
video_keys = [k for k in merged_info if merged_info[k].get("dtype") == "video"]
|
||||
|
||||
for vk in video_keys:
|
||||
video_infos = [m.features.get(vk, {}).get("info") or {} for m in all_metadata]
|
||||
base_video_info = video_infos[0]
|
||||
|
||||
merged_encoder_info: dict = {}
|
||||
fallback_keys: list[str] = []
|
||||
for info_key in VIDEO_ENCODER_INFO_KEYS:
|
||||
values = [info.get(info_key, None) for info in video_infos]
|
||||
first_value = values[0]
|
||||
all_match = all(v == first_value for v in values[1:])
|
||||
|
||||
if all_match:
|
||||
merged_encoder_info[info_key] = first_value
|
||||
else:
|
||||
fallback_keys.append(info_key)
|
||||
merged_encoder_info[info_key] = {} if info_key == "video.extra_options" else None
|
||||
|
||||
if fallback_keys:
|
||||
logging.warning(
|
||||
f"Merging heterogeneous or incomplete video encoder metadata for feature {vk}. "
|
||||
f"Setting these keys to null: {fallback_keys}.",
|
||||
)
|
||||
|
||||
merged_info[vk]["info"] = {**base_video_info, **merged_encoder_info}
|
||||
# TODO(CarolinePascal): make this variable once we have support for other video backends.
|
||||
merged_info[vk]["info"]["video.video_backend"] = "pyav"
|
||||
|
||||
return merged_info
|
||||
|
||||
|
||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
"""Validates that all dataset metadata have consistent properties.
|
||||
|
||||
Ensures all datasets have the same fps, robot_type, and features to guarantee
|
||||
compatibility when aggregating them into a single dataset.
|
||||
Video encoder info is not considered for validation but is merged during aggregation in ``merge_video_feature_info_for_aggregate``.
|
||||
|
||||
Args:
|
||||
all_metadata: List of LeRobotDatasetMetadata objects to validate.
|
||||
@@ -74,7 +120,7 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
raise ValueError(
|
||||
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
|
||||
)
|
||||
if features != meta.features:
|
||||
if not features_equal_for_merge(features, meta.features):
|
||||
raise ValueError(
|
||||
f"Same features is expected, but got features={meta.features} instead of {features}."
|
||||
)
|
||||
@@ -274,7 +320,8 @@ def aggregate_datasets(
|
||||
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
|
||||
]
|
||||
)
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
fps, robot_type, _ = validate_all_metadata(all_metadata)
|
||||
features = merge_video_feature_info_for_aggregate(all_metadata)
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
|
||||
dst_meta = LeRobotDatasetMetadata.create(
|
||||
@@ -332,7 +379,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
videos_idx: Dictionary tracking video chunk and file indices.
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
|
||||
Returns:
|
||||
dict: Updated videos_idx with current chunk and file indices.
|
||||
"""
|
||||
@@ -414,9 +460,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
# TODO(CarolinePascal): Move the check before the loop to avoid failing in the middle + add possibility to re-encode the video if the check fails
|
||||
concatenate_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
compatibility_check=True,
|
||||
)
|
||||
# Update duration of this destination file
|
||||
dst_file_durations[dst_key] = current_dst_duration + src_duration
|
||||
|
||||
@@ -512,7 +512,7 @@ def compute_episode_stats(
|
||||
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
if features[key]["dtype"] in {"string", "language"}:
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
|
||||
@@ -24,6 +24,7 @@ import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
|
||||
from lerobot.utils.feature_utils import _validate_feature_names
|
||||
from lerobot.utils.utils import flatten_dict
|
||||
@@ -35,12 +36,12 @@ from .io_utils import (
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from .language import DEFAULT_TOOLS, LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
check_version_compatibility,
|
||||
@@ -176,7 +177,6 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -342,6 +342,49 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def has_language_columns(self) -> bool:
|
||||
"""Return ``True`` if the dataset declares any language column.
|
||||
|
||||
Used to gate language-aware code paths (collate, render step) so
|
||||
unannotated datasets keep PyTorch's default collate behavior.
|
||||
"""
|
||||
return any(col in self.features for col in LANGUAGE_COLUMNS)
|
||||
|
||||
@property
|
||||
def tools(self) -> list[dict]:
|
||||
"""OpenAI-style tool schemas declared by this dataset.
|
||||
|
||||
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
|
||||
can mutate the result safely. Falls back to
|
||||
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
|
||||
``say`` schema) when the dataset doesn't declare any — that way
|
||||
unannotated datasets and chat-template consumers
|
||||
(``apply_chat_template(messages, tools=meta.tools)``) keep
|
||||
working out of the box.
|
||||
|
||||
Implementations live under :mod:`lerobot.tools` (one file per
|
||||
tool); see ``docs/source/tools.mdx`` for the authoring guide.
|
||||
"""
|
||||
declared = self.info.tools
|
||||
if declared:
|
||||
return [dict(t) for t in declared]
|
||||
return [dict(t) for t in DEFAULT_TOOLS]
|
||||
|
||||
@tools.setter
|
||||
def tools(self, value: list[dict] | None) -> None:
|
||||
"""Persist a tool catalog to ``meta/info.json`` and reload metadata.
|
||||
|
||||
Writes ``value`` into the on-disk ``info.json`` (or clears the
|
||||
``tools`` key when ``value`` is ``None`` or empty), then reloads
|
||||
``self.info`` so the in-memory metadata matches what's on disk.
|
||||
Saves callers from hand-editing ``info.json`` and re-instantiating
|
||||
the metadata object.
|
||||
"""
|
||||
self.info.tools = [dict(t) for t in value] if value else None
|
||||
write_info(self.info, self.root)
|
||||
self.info = load_info(self.root)
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
@@ -534,10 +577,23 @@ class LeRobotDatasetMetadata:
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||
write_stats(self.stats, self.root)
|
||||
|
||||
def update_video_info(self, video_key: str | None = None) -> None:
|
||||
"""
|
||||
def update_video_info(
|
||||
self,
|
||||
video_key: str | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
) -> None:
|
||||
"""Populate per-feature video info in ``info.json``.
|
||||
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
|
||||
Args:
|
||||
video_key: If provided, only update this video key. Otherwise update
|
||||
all video keys in the dataset.
|
||||
camera_encoder: Encoder configuration used to produce the
|
||||
videos. When provided, its fields are recorded as
|
||||
``video.<field>`` entries alongside the stream-derived
|
||||
``video.*`` entries (see :func:`get_video_info`).
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
@@ -546,7 +602,7 @@ class LeRobotDatasetMetadata:
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info.features[key]["info"] = get_video_info(video_path)
|
||||
self.info.features[key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
@@ -657,7 +713,6 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
|
||||
@@ -295,9 +295,4 @@ class DatasetReader:
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self._meta.tasks.iloc[task_idx].name
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
@@ -26,7 +26,7 @@ This module provides utilities for:
|
||||
import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -36,6 +36,7 @@ import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.utils import flatten_dict
|
||||
|
||||
@@ -60,9 +61,14 @@ from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
VIDEO_DIR,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import encode_video_frames, get_video_info
|
||||
from .video_utils import (
|
||||
encode_video_frames,
|
||||
get_video_info,
|
||||
reencode_video,
|
||||
)
|
||||
|
||||
|
||||
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
@@ -95,6 +101,11 @@ def delete_episodes(
|
||||
) -> LeRobotDataset:
|
||||
"""Delete episodes from a LeRobotDataset and create a new dataset.
|
||||
|
||||
Video segments that need re-encoding (because the source file mixes kept and
|
||||
deleted episodes) are re-encoded with the source dataset's existing encoder
|
||||
settings — read back from ``meta/info.json`` — so the output dataset stays
|
||||
consistent with its own metadata.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
episode_indices: List of episode indices to delete.
|
||||
@@ -157,6 +168,11 @@ def split_dataset(
|
||||
) -> dict[str, LeRobotDataset]:
|
||||
"""Split a LeRobotDataset into multiple smaller datasets.
|
||||
|
||||
Video segments that need re-encoding (because the source file mixes episodes
|
||||
that fall into different splits) are re-encoded with the source dataset's
|
||||
existing encoder settings — read back from ``meta/info.json`` — so each
|
||||
output split stays consistent with its own metadata.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset to split.
|
||||
splits: Either a dict mapping split names to episode indices, or a dict mapping
|
||||
@@ -578,8 +594,7 @@ def _keep_episodes_from_video_with_av(
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
fps: float,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
@@ -593,8 +608,7 @@ def _keep_episodes_from_video_with_av(
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
fps: Frame rate of the video.
|
||||
vcodec: Video codec to use for encoding.
|
||||
pix_fmt: Pixel format for output video.
|
||||
camera_encoder: Video encoder settings used to re-encode the kept frames.
|
||||
"""
|
||||
from fractions import Fraction
|
||||
|
||||
@@ -619,12 +633,13 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Convert fps to Fraction for PyAV compatibility.
|
||||
fps_fraction = Fraction(fps).limit_denominator(1000)
|
||||
v_out = out.add_stream(vcodec, rate=fps_fraction)
|
||||
codec_options = camera_encoder.get_codec_options(as_strings=True)
|
||||
v_out = out.add_stream(camera_encoder.vcodec, rate=fps_fraction, options=codec_options)
|
||||
|
||||
# PyAV type stubs don't distinguish video streams from audio/subtitle streams.
|
||||
v_out.width = v_in.codec_context.width
|
||||
v_out.height = v_in.codec_context.height
|
||||
v_out.pix_fmt = pix_fmt
|
||||
v_out.pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
# Set time_base to match the frame rate for proper timestamp handling.
|
||||
v_out.time_base = Fraction(1, int(fps))
|
||||
@@ -687,14 +702,14 @@ def _copy_and_reindex_videos(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_mapping: dict[int, int],
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> dict[int, dict]:
|
||||
"""Copy and filter video files, only re-encoding files with deleted episodes.
|
||||
|
||||
For video files that only contain kept episodes, we copy them directly.
|
||||
For files with mixed kept/deleted episodes, we use PyAV filters to efficiently
|
||||
re-encode only the desired segments.
|
||||
re-encode only the desired segments. The encoder used for re-encoding is
|
||||
derived per video key from the source dataset's ``meta/info.json`` so the
|
||||
destination metadata keeps describing the videos accurately.
|
||||
|
||||
Args:
|
||||
src_dataset: Source dataset to copy from
|
||||
@@ -711,6 +726,9 @@ def _copy_and_reindex_videos(
|
||||
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
logging.info(f"Processing videos for {video_key}")
|
||||
camera_encoder = VideoEncoderConfig.from_video_info(
|
||||
src_dataset.meta.info.features.get(video_key, {}).get("info")
|
||||
)
|
||||
|
||||
if dst_meta.video_path is None:
|
||||
raise ValueError("Destination metadata has no video_path defined")
|
||||
@@ -792,8 +810,7 @@ def _copy_and_reindex_videos(
|
||||
dst_video_path,
|
||||
episodes_to_keep_ranges,
|
||||
src_dataset.meta.fps,
|
||||
vcodec,
|
||||
pix_fmt,
|
||||
camera_encoder,
|
||||
)
|
||||
|
||||
cumulative_ts = 0.0
|
||||
@@ -1264,11 +1281,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
@@ -1282,11 +1295,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
vcodec: Video codec (libsvtav1, h264, hevc).
|
||||
pix_fmt: Pixel format (yuv420p, etc.).
|
||||
g: GOP size (group of pictures).
|
||||
crf: Constant Rate Factor (quality).
|
||||
fast_decode: Fast decode tuning parameter.
|
||||
camera_encoder: Video encoder settings used for calibration encoding.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
@@ -1322,11 +1331,7 @@ def _estimate_frame_size_via_calibration(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
camera_encoder=camera_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1644,11 +1649,7 @@ def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
@@ -1663,11 +1664,8 @@ def convert_image_to_video_dataset(
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
camera_encoder: Video encoder settings
|
||||
(``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`).
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
|
||||
@@ -1676,6 +1674,9 @@ def convert_image_to_video_dataset(
|
||||
Returns:
|
||||
New LeRobotDataset with images encoded as videos
|
||||
"""
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
@@ -1699,7 +1700,10 @@ def convert_image_to_video_dataset(
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
logging.info(
|
||||
f"Video codec: {camera_encoder.vcodec}, pixel format: {camera_encoder.pix_fmt}, "
|
||||
f"GOP: {camera_encoder.g}, CRF: {camera_encoder.crf}"
|
||||
)
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
@@ -1769,11 +1773,7 @@ def convert_image_to_video_dataset(
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
camera_encoder=camera_encoder,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
@@ -1815,11 +1815,7 @@ def convert_image_to_video_dataset(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
camera_encoder=camera_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1865,7 +1861,9 @@ def convert_image_to_video_dataset(
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(video_path)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(
|
||||
video_path, camera_encoder=camera_encoder
|
||||
)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
@@ -1888,3 +1886,83 @@ def convert_image_to_video_dataset(
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
|
||||
def _reencode_video_worker(args: tuple) -> Path:
|
||||
"""Picklable worker for :func:`reencode_dataset`'s process pool."""
|
||||
video_path, camera_encoder, encoder_threads = args
|
||||
reencode_video(
|
||||
input_video_path=video_path,
|
||||
output_video_path=video_path,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
return video_path
|
||||
|
||||
|
||||
def reencode_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
encoder_threads: int | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Re-encode every video in a dataset with a new set of encoding parameters.
|
||||
|
||||
Videos are re-encoded in-place and the video information in ``info.json`` is refreshed.
|
||||
|
||||
Args:
|
||||
dataset: An existing :class:`LeRobotDataset` whose videos will be
|
||||
re-encoded.
|
||||
camera_encoder: Target encoder configuration applied to every video
|
||||
file.
|
||||
encoder_threads: Per-encoder thread count forwarded to
|
||||
:func:`reencode_video`. ``None`` lets the codec decide.
|
||||
num_workers: Number of parallel processes. ``None`` or ``0`` means
|
||||
sequential (no multiprocessing); ``1+`` spawns a
|
||||
:class:`~concurrent.futures.ProcessPoolExecutor`.
|
||||
|
||||
Returns:
|
||||
The same :class:`LeRobotDataset` instance with its metadata updated
|
||||
on disk.
|
||||
"""
|
||||
meta = dataset.meta
|
||||
video_paths_list = []
|
||||
|
||||
# Only re-encode if the videos are not already encoded with the given video encoding parameters
|
||||
for video_key in meta.video_keys:
|
||||
current_info = meta.info.features[video_key].get("info", {})
|
||||
current_encoder = VideoEncoderConfig.from_video_info(current_info)
|
||||
if current_encoder != camera_encoder:
|
||||
video_paths_list.extend((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
|
||||
else:
|
||||
logging.info(f"{video_key} videos are already encoded with {camera_encoder}. Nothing to do.")
|
||||
|
||||
if len(video_paths_list) == 0:
|
||||
logging.warning("Dataset has no videos to re-encode.")
|
||||
return dataset
|
||||
logging.info(f"Re-encoding {len(video_paths_list)} video file(s) with {camera_encoder}")
|
||||
|
||||
worker_args = [(vp, camera_encoder, encoder_threads) for vp in video_paths_list]
|
||||
if num_workers and num_workers > 1:
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as pool:
|
||||
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
|
||||
for future in tqdm(
|
||||
as_completed(futures),
|
||||
total=len(futures),
|
||||
desc="Re-encoding videos",
|
||||
):
|
||||
future.result()
|
||||
else:
|
||||
for args in tqdm(worker_args, desc="Re-encoding videos"):
|
||||
_reencode_video_worker(args)
|
||||
|
||||
# Refresh video info in metadata for every video key.
|
||||
for vid_key in meta.video_keys:
|
||||
video_path = meta.root / meta.get_video_file_path(0, vid_key)
|
||||
meta.info.features[vid_key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
|
||||
write_info(meta.info, meta.root)
|
||||
logging.info("Dataset metadata updated.")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -31,6 +31,8 @@ import PIL.Image
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
|
||||
from .compute_stats import compute_episode_stats
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .feature_utils import (
|
||||
@@ -65,14 +67,19 @@ def _encode_video_worker(
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
|
||||
img_dir,
|
||||
temp_path,
|
||||
fps,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
@@ -89,20 +96,22 @@ class DatasetWriter:
|
||||
self,
|
||||
meta: LeRobotDatasetMetadata,
|
||||
root: Path,
|
||||
vcodec: str,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
encoder_threads: int | None,
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
initial_frames: int = 0,
|
||||
):
|
||||
"""Initialize the writer with metadata, codec, and encoding config.
|
||||
"""Initialize the writer with metadata, codec, and encoder config.
|
||||
|
||||
Args:
|
||||
meta: Dataset metadata instance (used for feature schema, chunk
|
||||
settings, and episode persistence).
|
||||
root: Local dataset root directory.
|
||||
vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``).
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos.
|
||||
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
|
||||
@@ -111,7 +120,7 @@ class DatasetWriter:
|
||||
"""
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._vcodec = vcodec
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
@@ -241,7 +250,14 @@ class DatasetWriter:
|
||||
for key, ft in self._meta.features.items():
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
stacked_values = np.stack(episode_buffer[key])
|
||||
|
||||
# `shape=(1,)` numeric features are serialized as `datasets.Value`, which expects scalars.
|
||||
# Normalizing to `(N,)` keeps save semantics stable across dependency versions.
|
||||
if tuple(ft["shape"]) == (1,) and ft["dtype"] != "string":
|
||||
stacked_values = stacked_values.reshape(episode_length)
|
||||
|
||||
episode_buffer[key] = stacked_values
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
self._wait_image_writer()
|
||||
@@ -284,7 +300,7 @@ class DatasetWriter:
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._vcodec,
|
||||
self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self._meta.video_keys
|
||||
@@ -495,7 +511,7 @@ class DatasetWriter:
|
||||
|
||||
# Update video info (only needed when first episode is encoded)
|
||||
if episode_index == 0:
|
||||
self._meta.update_video_info(video_key)
|
||||
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
@@ -564,7 +580,12 @@ class DatasetWriter:
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
|
||||
return _encode_video_worker(
|
||||
video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads
|
||||
video_key,
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
)
|
||||
|
||||
def close_writer(self) -> None:
|
||||
|
||||
@@ -13,15 +13,23 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pprint import pformat
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
from .language import (
|
||||
LANGUAGE_PERSISTENT,
|
||||
is_language_column,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -46,7 +54,13 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
if is_language_column(key):
|
||||
hf_features[key] = (
|
||||
language_persistent_column_feature()
|
||||
if key == LANGUAGE_PERSISTENT
|
||||
else language_events_column_feature()
|
||||
)
|
||||
elif ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
@@ -108,6 +122,41 @@ def create_empty_dataset_info(
|
||||
)
|
||||
|
||||
|
||||
def features_equal_for_merge(features_a: dict[str, dict], features_b: dict[str, dict]) -> bool:
|
||||
"""Return whether two LeRobotDatasetMetadata ``features`` dicts are compatible for aggregation.
|
||||
|
||||
For video features, keys under ``info`` related to video encoding parameters are ignored during
|
||||
comparison as they do not prevent aggregation.
|
||||
"""
|
||||
|
||||
def _without_encoder_info_keys(feature: dict) -> dict:
|
||||
filtered = dict(feature)
|
||||
filtered_info = filtered.get("info")
|
||||
if isinstance(filtered_info, dict):
|
||||
filtered["info"] = {
|
||||
info_key: info_value
|
||||
for info_key, info_value in filtered_info.items()
|
||||
if info_key not in VIDEO_ENCODER_INFO_KEYS
|
||||
}
|
||||
return filtered
|
||||
|
||||
if set(features_a) != set(features_b):
|
||||
return False
|
||||
for key in features_a:
|
||||
fa_key = features_a[key]
|
||||
fb_key = features_b[key]
|
||||
if fa_key.get("dtype") != fb_key.get("dtype"):
|
||||
return False
|
||||
if fa_key.get("dtype") != "video":
|
||||
if fa_key != fb_key:
|
||||
return False
|
||||
continue
|
||||
|
||||
if _without_encoder_info_keys(fa_key) != _without_encoder_info_keys(fb_key):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
) -> bool:
|
||||
@@ -242,6 +291,8 @@ def validate_feature_dtype_and_shape(
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
elif expected_dtype == "language":
|
||||
return validate_feature_language(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
@@ -321,6 +372,30 @@ def validate_feature_string(name: str, value: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def validate_feature_language(name: str, value) -> str:
|
||||
"""Validate a feature that is expected to hold language annotations.
|
||||
|
||||
Language columns (``language_persistent`` / ``language_events``) are
|
||||
populated after recording by the annotation pipeline, not at record time.
|
||||
Any value supplied here is dropped before the frame is written, so a
|
||||
non-empty value almost certainly signals a mistake. We warn rather than
|
||||
fail to keep recording resilient.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
value: The value to validate.
|
||||
|
||||
Returns:
|
||||
str: Always an empty string — language values are non-fatal.
|
||||
"""
|
||||
if value is not None:
|
||||
logging.warning(
|
||||
f"The feature '{name}' is a 'language' column populated by the annotation pipeline, "
|
||||
f"not at record time. The provided value will be dropped."
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
|
||||
"""Validate the episode buffer before it's written to disk.
|
||||
|
||||
|
||||
@@ -31,10 +31,10 @@ from torchvision import transforms
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
|
||||
|
||||
from .language import LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
@@ -186,14 +186,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
@@ -265,11 +257,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
if key in LANGUAGE_COLUMNS:
|
||||
continue
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
elif first_item is None or isinstance(first_item, dict):
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
@@ -304,8 +298,9 @@ def item_to_torch(item: dict) -> dict:
|
||||
Returns:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
skip_keys = {"task", *LANGUAGE_COLUMNS}
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
242
src/lerobot/datasets/language.py
Normal file
242
src/lerobot/datasets/language.py
Normal file
@@ -0,0 +1,242 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import datasets
|
||||
import pyarrow as pa
|
||||
|
||||
LANGUAGE_PERSISTENT = "language_persistent"
|
||||
LANGUAGE_EVENTS = "language_events"
|
||||
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
|
||||
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
|
||||
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
|
||||
|
||||
CORE_STYLES = {
|
||||
"subtask",
|
||||
"plan",
|
||||
"memory",
|
||||
"motion",
|
||||
"interjection",
|
||||
"vqa",
|
||||
"trace",
|
||||
"task_aug",
|
||||
}
|
||||
# Project-local styles can be registered at import time by appending to
|
||||
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
|
||||
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
|
||||
# validation. Empty by default — populate from a downstream module that
|
||||
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
|
||||
# the new style's column.
|
||||
EXTENDED_STYLES: set[str] = set()
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||
|
||||
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
||||
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
|
||||
# feature key. Rows of every other style MUST have ``camera=None``. ``motion``
|
||||
# is intentionally NOT in this set: motion primitives are described in
|
||||
# robot-frame (joint / Cartesian) terms, not pixel space, so they are
|
||||
# camera-agnostic. ``trace`` is the pixel-trajectory event style and IS
|
||||
# view-dependent. The ``camera`` field nevertheless lives on
|
||||
# ``PERSISTENT_ROW_FIELDS`` too so the schema, validator, and resolver
|
||||
# behave symmetrically across the two columns; persistent rows simply
|
||||
# always have ``camera=None`` in practice today.
|
||||
VIEW_DEPENDENT_STYLES = {"vqa", "trace"}
|
||||
|
||||
LanguageColumn = Literal["language_persistent", "language_events"]
|
||||
|
||||
|
||||
def _json_arrow_type() -> pa.DataType:
|
||||
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
|
||||
return pa.json_() if hasattr(pa, "json_") else pa.string()
|
||||
|
||||
|
||||
def _json_feature() -> object:
|
||||
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
|
||||
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
|
||||
|
||||
|
||||
def language_persistent_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single persistent language row.
|
||||
|
||||
Persistent rows carry their own ``timestamp`` because they represent a state
|
||||
that became active at a specific moment and remains active until superseded.
|
||||
``timestamp`` is ``float32`` to match the timestamp dtype LeRobotDataset
|
||||
uses for frame data.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("timestamp", pa.float32(), nullable=False),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_event_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single event language row.
|
||||
|
||||
Event rows have no ``timestamp`` field: each event is stored on the dataset
|
||||
row whose frame timestamp is the event's firing time.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_persistent_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_persistent`` column."""
|
||||
return pa.list_(language_persistent_row_arrow_type())
|
||||
|
||||
|
||||
def language_events_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_events`` column."""
|
||||
return pa.list_(language_event_row_arrow_type())
|
||||
|
||||
|
||||
def language_persistent_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"timestamp": datasets.Value("float32"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_event_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for an event language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_persistent_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
|
||||
return datasets.List(language_persistent_row_feature())
|
||||
|
||||
|
||||
def language_events_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
|
||||
return datasets.List(language_event_row_feature())
|
||||
|
||||
|
||||
def language_feature_info() -> dict[str, dict]:
|
||||
"""Return the ``info["features"]`` entries for both language columns."""
|
||||
return {
|
||||
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
|
||||
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
|
||||
def is_language_column(key: str) -> bool:
|
||||
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
|
||||
return key in LANGUAGE_COLUMNS
|
||||
|
||||
|
||||
def is_view_dependent_style(style: str | None) -> bool:
|
||||
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
|
||||
return style in VIEW_DEPENDENT_STYLES
|
||||
|
||||
|
||||
def validate_camera_field(style: str | None, camera: str | None) -> None:
|
||||
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
|
||||
|
||||
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
|
||||
a non-view-dependent style carries one. Pipeline writers and the validator
|
||||
should call this on every emitted row.
|
||||
"""
|
||||
if is_view_dependent_style(style):
|
||||
if not camera:
|
||||
raise ValueError(
|
||||
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
|
||||
f"field referencing an 'observation.images.*' feature key."
|
||||
)
|
||||
elif camera is not None:
|
||||
raise ValueError(f"Rows of style {style!r} must have camera=None; got camera={camera!r}.")
|
||||
|
||||
|
||||
# --- Tool registry --------------------------------------------------------
|
||||
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
|
||||
# of OpenAI-style function schemas. The runtime / training stack reads them
|
||||
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
|
||||
# fallback when the dataset doesn't declare any). Implementations live
|
||||
# under :mod:`lerobot.tools` (one file per tool); see
|
||||
# ``docs/source/tools.mdx`` for the authoring guide.
|
||||
|
||||
SAY_TOOL_SCHEMA: dict = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak.",
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""Canonical schema for the ``say`` tool emitted by the steerable
|
||||
annotation pipeline (PR 2 Module 2). Single source of truth — PR 2's
|
||||
writer, PR 3's runtime tool registry, and the dataset visualizer all
|
||||
import this constant rather than duplicating the dict."""
|
||||
|
||||
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
|
||||
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
|
||||
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
|
||||
chat-template consumers (``apply_chat_template(messages, tools=...)``)
|
||||
keep working out of the box."""
|
||||
|
||||
|
||||
def column_for_style(style: str | None) -> LanguageColumn:
|
||||
"""Map a language style to the column where rows of that style are stored.
|
||||
|
||||
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
|
||||
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
|
||||
to :data:`LANGUAGE_EVENTS`.
|
||||
"""
|
||||
if style is None:
|
||||
return LANGUAGE_EVENTS
|
||||
if style in PERSISTENT_STYLES:
|
||||
return LANGUAGE_PERSISTENT
|
||||
if style in EVENT_ONLY_STYLES:
|
||||
return LANGUAGE_EVENTS
|
||||
raise ValueError(f"Unknown language style: {style!r}")
|
||||
545
src/lerobot/datasets/language_render.py
Normal file
545
src/lerobot/datasets/language_render.py
Normal file
@@ -0,0 +1,545 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
|
||||
from lerobot.utils.utils import unwrap_scalar
|
||||
|
||||
from .language import LANGUAGE_PERSISTENT, column_for_style
|
||||
|
||||
LanguageRow = dict[str, Any]
|
||||
RenderedMessages = dict[str, list[Any]]
|
||||
|
||||
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
||||
|
||||
|
||||
def active_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row of ``style`` that is active at time ``t``.
|
||||
|
||||
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
|
||||
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
|
||||
``camera`` selector. Only valid for persistent styles.
|
||||
"""
|
||||
_validate_persistent_resolver("active_at", style)
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if _timestamp(row) <= t
|
||||
]
|
||||
if not matches:
|
||||
return None
|
||||
latest_ts = max(_timestamp(row) for row in matches)
|
||||
return _select_one(
|
||||
[row for row in matches if _timestamp(row) == latest_ts],
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
)
|
||||
|
||||
|
||||
EMITTED_AT_TOLERANCE_S = 0.1
|
||||
"""Half-window for matching persistent rows to a frame timestamp in
|
||||
``emitted_at``. Persistent timestamps come from parquet (float32) and ``t``
|
||||
is also a float32 from parquet, so in the ideal hot path an exact match
|
||||
would suffice — but any caller that derives ``t`` arithmetically (e.g.
|
||||
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
|
||||
common arithmetic drift without admitting frames that are visibly far
|
||||
apart at typical control rates (30–100 Hz). This does mean two persistent
|
||||
rows of the same selector emitted within 0.1 s of each other cannot be
|
||||
told apart by ``emitted_at`` — acceptable because persistent annotations
|
||||
(subtask / plan / memory transitions) change on a human-action timescale,
|
||||
not at the camera frame rate."""
|
||||
|
||||
|
||||
def emitted_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the row of ``style`` emitted at exactly time ``t``.
|
||||
|
||||
For persistent styles, this matches persistent rows whose own ``timestamp``
|
||||
is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why
|
||||
we use a tolerance instead of bit-equality). For event styles, the
|
||||
``events`` list is assumed to come from the dataset row at frame ``t``
|
||||
(event rows carry no timestamp of their own), so all matching event rows
|
||||
are considered emitted at ``t``. ``camera`` filters by the row's
|
||||
``camera`` field — required to disambiguate when multiple view-dependent
|
||||
rows share ``(t, role)`` across cameras.
|
||||
"""
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S
|
||||
]
|
||||
else:
|
||||
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
|
||||
|
||||
def nth_prev(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that was active ``offset`` steps before ``t``.
|
||||
|
||||
Walks back through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions before the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def nth_next(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
|
||||
|
||||
Walks forward through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions after the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def render_sample(
|
||||
*,
|
||||
recipe: TrainingRecipe,
|
||||
persistent: Sequence[LanguageRow] | None,
|
||||
events: Sequence[LanguageRow] | None,
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None = None,
|
||||
dataset_ctx: Any | None = None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render the chat-style messages for a single dataset sample.
|
||||
|
||||
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
|
||||
at frame timestamp ``t``, then expands the recipe's message templates.
|
||||
Returns ``None`` if the resolved sample contains no target message.
|
||||
"""
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
bindings = _resolve_bindings(
|
||||
selected_recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
return _render_message_recipe(selected_recipe, bindings)
|
||||
|
||||
|
||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||
if recipe.blend is None:
|
||||
return recipe
|
||||
|
||||
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
|
||||
if total_weight <= 0:
|
||||
raise ValueError("Blend weights must sum to a positive value.")
|
||||
|
||||
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
|
||||
cumulative = 0.0
|
||||
last_component: TrainingRecipe | None = None
|
||||
for component in recipe.blend.values():
|
||||
last_component = component
|
||||
cumulative += component.weight or 0.0
|
||||
if draw < cumulative:
|
||||
return component
|
||||
assert last_component is not None
|
||||
return last_component
|
||||
|
||||
|
||||
def _resolve_bindings(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> dict[str, LanguageRow | str | None]:
|
||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||
bindings: dict[str, LanguageRow | str | None] = {
|
||||
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
|
||||
}
|
||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||
for name, spec in specs.items():
|
||||
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
||||
return bindings
|
||||
|
||||
|
||||
def _resolve_task(
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow] = (),
|
||||
sample_idx: int = 0,
|
||||
) -> str | None:
|
||||
"""Return the task string for ``sample_idx``.
|
||||
|
||||
Resolution order:
|
||||
|
||||
1. Explicit ``task`` override (caller-supplied) wins.
|
||||
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
|
||||
deterministically pick one by ``sample_idx`` so each frame of an
|
||||
episode rotates through the available rephrasings across an epoch.
|
||||
This realizes Xiao 2022 / CAST-style task-prompt diversity without
|
||||
changing ``meta/tasks.parquet`` and without forcing recipes to opt
|
||||
in: ``${task}`` automatically picks a rephrasing when one exists,
|
||||
and falls back to the canonical task otherwise. Recipes that want
|
||||
the literal canonical task can override the binding.
|
||||
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
|
||||
backed by ``meta/tasks.parquet``).
|
||||
"""
|
||||
if task is not None:
|
||||
return task
|
||||
|
||||
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
|
||||
if aug_rows:
|
||||
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||
# is process-randomized).
|
||||
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
|
||||
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||
chosen = aug_rows[idx].get("content")
|
||||
if chosen:
|
||||
return str(chosen)
|
||||
|
||||
if dataset_ctx is None:
|
||||
return None
|
||||
if isinstance(dataset_ctx, dict):
|
||||
return dataset_ctx.get("task")
|
||||
return getattr(dataset_ctx, "task", None)
|
||||
|
||||
|
||||
def _resolve_spec(
|
||||
spec: str,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
) -> LanguageRow | None:
|
||||
"""Parse a single binding's resolver expression and dispatch to its function."""
|
||||
match = _RESOLVER_RE.match(spec.strip())
|
||||
if match is None:
|
||||
raise ValueError(f"Invalid resolver expression: {spec!r}")
|
||||
name = match.group("name")
|
||||
kwargs = _parse_resolver_args(match.group("args"))
|
||||
kwargs.pop("t_arg", None)
|
||||
|
||||
if name == "emitted_at":
|
||||
return emitted_at(t, persistent=persistent, events=events, **kwargs)
|
||||
if name == "active_at":
|
||||
return active_at(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_prev":
|
||||
return nth_prev(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_next":
|
||||
return nth_next(t, persistent=persistent, **kwargs)
|
||||
raise ValueError(f"Unknown language resolver: {name!r}")
|
||||
|
||||
|
||||
def _parse_resolver_args(args: str) -> dict[str, Any]:
|
||||
"""Parse a comma-separated resolver argument list into a kwargs dict."""
|
||||
kwargs: dict[str, Any] = {}
|
||||
if not args.strip():
|
||||
return kwargs
|
||||
|
||||
parts = [part.strip() for part in args.split(",") if part.strip()]
|
||||
for part in parts:
|
||||
if part == "t":
|
||||
kwargs["t_arg"] = True
|
||||
continue
|
||||
if "=" not in part:
|
||||
raise ValueError(f"Invalid resolver argument: {part!r}")
|
||||
key, value = (item.strip() for item in part.split("=", 1))
|
||||
if key == "offset":
|
||||
kwargs[key] = int(value)
|
||||
else:
|
||||
kwargs[key] = value.strip("\"'")
|
||||
return kwargs
|
||||
|
||||
|
||||
def _render_message_recipe(
|
||||
recipe: TrainingRecipe,
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> RenderedMessages | None:
|
||||
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
|
||||
assert recipe.messages is not None
|
||||
messages: list[dict[str, Any]] = []
|
||||
streams: list[str | None] = []
|
||||
target_indices: list[int] = []
|
||||
|
||||
for turn in recipe.messages:
|
||||
if turn.if_present is not None and bindings.get(turn.if_present) is None:
|
||||
continue
|
||||
|
||||
message = {"role": turn.role}
|
||||
if turn.content is not None:
|
||||
message["content"] = _render_content(turn.content, bindings)
|
||||
|
||||
if turn.tool_calls_from is not None:
|
||||
row = bindings.get(turn.tool_calls_from)
|
||||
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
|
||||
if tool_calls:
|
||||
message["tool_calls"] = copy.deepcopy(tool_calls)
|
||||
|
||||
message_idx = len(messages)
|
||||
messages.append(message)
|
||||
streams.append(turn.stream)
|
||||
if turn.target:
|
||||
target_indices.append(message_idx)
|
||||
|
||||
if not target_indices:
|
||||
return None
|
||||
|
||||
rendered = {
|
||||
"messages": messages,
|
||||
"message_streams": streams,
|
||||
"target_message_indices": target_indices,
|
||||
}
|
||||
_validate_rendered(rendered)
|
||||
return rendered
|
||||
|
||||
|
||||
def _render_content(
|
||||
content: str | list[dict[str, Any]],
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Substitute bindings into a string or each string field of multimodal blocks."""
|
||||
if isinstance(content, str):
|
||||
return _substitute(content, bindings)
|
||||
|
||||
rendered_blocks = []
|
||||
for block in content:
|
||||
rendered_block = copy.deepcopy(block)
|
||||
for key, value in rendered_block.items():
|
||||
if isinstance(value, str):
|
||||
rendered_block[key] = _substitute(value, bindings)
|
||||
rendered_blocks.append(rendered_block)
|
||||
return rendered_blocks
|
||||
|
||||
|
||||
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
|
||||
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
|
||||
|
||||
def replace(match: re.Match[str]) -> str:
|
||||
"""Resolve a single ``${name}`` match to its bound string value."""
|
||||
name = match.group(1)
|
||||
if name not in bindings:
|
||||
raise ValueError(f"Unknown template binding: {name!r}")
|
||||
value = bindings[name]
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, dict):
|
||||
content = value.get("content")
|
||||
return "" if content is None else str(content)
|
||||
return str(value)
|
||||
|
||||
return PLACEHOLDER_RE.sub(replace, template)
|
||||
|
||||
|
||||
def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
"""Sanity-check the rendered output for stream/target alignment."""
|
||||
messages = rendered["messages"]
|
||||
streams = rendered["message_streams"]
|
||||
target_indices = rendered["target_message_indices"]
|
||||
|
||||
if len(streams) != len(messages):
|
||||
raise ValueError("message_streams must be aligned with messages.")
|
||||
if not target_indices:
|
||||
raise ValueError("Rendered samples must contain at least one target message.")
|
||||
for idx in target_indices:
|
||||
if idx < 0 or idx >= len(messages):
|
||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||
# ``stream`` is enforced non-None at MessageTurn construction time
|
||||
# (see ``MessageTurn.__post_init__``), so a missing stream here would
|
||||
# mean the dataclass invariant was bypassed; no need to re-check.
|
||||
|
||||
|
||||
def _nth_relative(
|
||||
name: str,
|
||||
t: float,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None,
|
||||
offset: int,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
|
||||
_validate_persistent_resolver(name, style)
|
||||
if abs(offset) < 1:
|
||||
raise ValueError(f"{name} offset must be non-zero.")
|
||||
|
||||
rows = sorted(
|
||||
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
|
||||
key=_row_sort_key,
|
||||
)
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
anchor_idx = None
|
||||
for idx, row in enumerate(rows):
|
||||
if _timestamp(row) <= t:
|
||||
anchor_idx = idx
|
||||
else:
|
||||
break
|
||||
|
||||
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
|
||||
|
||||
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
|
||||
return None
|
||||
return rows[target_idx]
|
||||
|
||||
|
||||
def _validate_persistent_resolver(name: str, style: str | None) -> None:
|
||||
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
|
||||
if style is None:
|
||||
raise ValueError(f"{name} requires a persistent style.")
|
||||
if column_for_style(style) != LANGUAGE_PERSISTENT:
|
||||
raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
|
||||
|
||||
|
||||
def _matching_rows(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> list[LanguageRow]:
|
||||
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
|
||||
return [
|
||||
row
|
||||
for row in rows
|
||||
if (style is None or row.get("style") == style)
|
||||
and (role is None or row.get("role") == role)
|
||||
and (tool_name is None or _row_has_tool_name(row, tool_name))
|
||||
and (camera is None or row.get("camera") == camera)
|
||||
]
|
||||
|
||||
|
||||
def _select_one(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the single matching row, or raise if the resolver is ambiguous.
|
||||
|
||||
Multiple matches always raise — even when the caller already passed
|
||||
some selectors — because remaining ambiguity means the data has
|
||||
several rows that look identical to the resolver and the caller
|
||||
needs to pin down a specific one (e.g. add ``camera=...`` for VQA
|
||||
rows shared across cameras).
|
||||
"""
|
||||
if not rows:
|
||||
return None
|
||||
if len(rows) > 1:
|
||||
raise ValueError(
|
||||
f"Ambiguous resolver for style={style!r} role={role!r} "
|
||||
f"tool_name={tool_name!r} camera={camera!r}: {len(rows)} matching rows. "
|
||||
f"Add a selector that distinguishes them."
|
||||
)
|
||||
return rows[0]
|
||||
|
||||
|
||||
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||
"""Stable sort key for both persistent and event rows.
|
||||
|
||||
Event rows lack ``timestamp`` (it is implicit in the frame), so default
|
||||
to ``0.0`` — within a single frame all event rows share the same sort
|
||||
bucket and are tiebroken by ``(style, role)``.
|
||||
"""
|
||||
timestamp = row.get("timestamp")
|
||||
ts = float(unwrap_scalar(timestamp)) if timestamp is not None else 0.0
|
||||
return (ts, row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _timestamp(row: LanguageRow) -> float:
|
||||
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
|
||||
return float(unwrap_scalar(row["timestamp"]))
|
||||
|
||||
|
||||
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
|
||||
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
|
||||
for tool_call in row.get("tool_calls") or []:
|
||||
if isinstance(tool_call, str):
|
||||
continue
|
||||
function = tool_call.get("function") if isinstance(tool_call, dict) else None
|
||||
if isinstance(function, dict) and function.get("name") == tool_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
|
||||
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
|
||||
normalized = []
|
||||
for row in rows:
|
||||
if row is None:
|
||||
continue
|
||||
if hasattr(row, "as_py"):
|
||||
row = row.as_py()
|
||||
if not isinstance(row, dict):
|
||||
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
|
||||
normalized.append(dict(row))
|
||||
return normalized
|
||||
@@ -24,6 +24,7 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
|
||||
|
||||
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
@@ -36,8 +37,7 @@ from .utils import (
|
||||
)
|
||||
from .video_utils import (
|
||||
StreamingVideoEncoder,
|
||||
get_safe_default_codec,
|
||||
resolve_vcodec,
|
||||
get_safe_default_video_backend,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -59,10 +59,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -183,16 +183,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
|
||||
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
|
||||
camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras
|
||||
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults`
|
||||
is used by the writer.
|
||||
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
|
||||
codec decide.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
|
||||
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
|
||||
streaming encoding. Defaults to 30 (~1s at 30fps).
|
||||
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
|
||||
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
|
||||
libsvtav1 and 'threads' for h264/hevc.
|
||||
|
||||
Note:
|
||||
Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to
|
||||
@@ -207,10 +206,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
self._return_uint8 = return_uint8
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._vcodec = resolve_vcodec(vcodec)
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
if self._requested_root is not None:
|
||||
@@ -273,12 +271,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(self.meta.video_keys) > 0:
|
||||
streaming_enc = self._build_streaming_encoder(
|
||||
self.meta.fps, self._vcodec, encoder_queue_maxsize, encoder_threads
|
||||
self.meta.fps,
|
||||
camera_encoder,
|
||||
encoder_queue_maxsize,
|
||||
encoder_threads,
|
||||
)
|
||||
self.writer = DatasetWriter(
|
||||
meta=self.meta,
|
||||
root=self.root,
|
||||
vcodec=self._vcodec,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -320,17 +321,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
@staticmethod
|
||||
def _build_streaming_encoder(
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
encoder_queue_maxsize: int,
|
||||
encoder_threads: int | None,
|
||||
) -> StreamingVideoEncoder:
|
||||
return StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=None,
|
||||
camera_encoder=camera_encoder,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
@@ -647,7 +644,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -678,20 +675,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: Video decoding backend (used when reading back).
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos. ``1`` means encode immediately.
|
||||
vcodec: Video codec for encoding. Options include ``'libsvtav1'``,
|
||||
``'h264'``, ``'hevc'``, ``'auto'``.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
metadata_buffer_size: Number of episode metadata records to buffer
|
||||
before flushing to parquet.
|
||||
streaming_encoding: If ``True``, encode video frames in real-time
|
||||
during capture instead of writing images first.
|
||||
encoder_queue_maxsize: Max buffered frames per camera when using
|
||||
streaming encoding.
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
|
||||
Returns:
|
||||
A new :class:`LeRobotDataset` in write mode.
|
||||
"""
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
@@ -712,23 +709,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
# Create writer
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(fps, vcodec, encoder_queue_maxsize, encoder_threads)
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
vcodec=vcodec,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -751,12 +748,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
force_cache_sync: bool = False,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Resume recording on an existing dataset.
|
||||
|
||||
@@ -779,13 +776,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: Video decoding backend for reading back data.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos.
|
||||
vcodec: Video codec for encoding.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
image_writer_processes: Subprocesses for async image writing.
|
||||
image_writer_threads: Threads for async image writing.
|
||||
streaming_encoding: If ``True``, encode video in real-time during
|
||||
capture.
|
||||
encoder_queue_maxsize: Max buffered frames per camera for streaming.
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
|
||||
Returns:
|
||||
A :class:`LeRobotDataset` in write mode, ready to append episodes.
|
||||
@@ -796,7 +795,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
|
||||
"the shared cache. Please provide a local directory path."
|
||||
)
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj._requested_root = Path(root)
|
||||
@@ -805,11 +803,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
if obj._requested_root is not None:
|
||||
obj._requested_root.mkdir(exist_ok=True, parents=True)
|
||||
@@ -818,21 +814,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.meta = LeRobotDatasetMetadata(
|
||||
obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
|
||||
obj._encoder_threads = encoder_threads
|
||||
obj.root = obj.meta.root
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
# Create writer for appending
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
obj.meta.fps, vcodec, encoder_queue_maxsize, encoder_threads
|
||||
obj.meta.fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
vcodec=vcodec,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
|
||||
174
src/lerobot/datasets/pyav_utils.py
Normal file
174
src/lerobot/datasets/pyav_utils.py
Normal file
@@ -0,0 +1,174 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyAV-based compatibility checks for :class:`VideoEncoderConfig`.
|
||||
|
||||
Centralises all :mod:`av` introspection of the bundled FFmpeg build.
|
||||
Checks degrade to a no-op when the target codec isn't available locally.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import av
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
try:
|
||||
return av.codec.Codec(vcodec, "w")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_options_by_name(vcodec: str) -> dict[str, av.option.Option]:
|
||||
"""Private-option name → PyAV ``Option`` for *vcodec* (empty if unavailable)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
return {}
|
||||
return {opt.name: opt for opt in codec.descriptor.options}
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_video_formats(vcodec: str) -> tuple[str, ...]:
|
||||
"""Pixel formats accepted by *vcodec* in PyAV's preferred order (empty if unknown)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
return ()
|
||||
return tuple(fmt.name for fmt in (codec.video_formats or []))
|
||||
|
||||
|
||||
def detect_available_encoders_pyav(encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of *encoders* available as video encoders in the local FFmpeg build.
|
||||
|
||||
Each name is probed directly via :func:`get_codec`; input order is preserved.
|
||||
"""
|
||||
if isinstance(encoders, str):
|
||||
encoders = [encoders]
|
||||
|
||||
available: list[str] = []
|
||||
for name in encoders:
|
||||
codec = get_codec(name)
|
||||
if codec is not None and codec.type == "video":
|
||||
available.append(name)
|
||||
else:
|
||||
logger.debug("encoder '%s' not available as video encoder", name)
|
||||
return available
|
||||
|
||||
|
||||
def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Option) -> None:
|
||||
"""Range-check numeric *value* and choice-check string *value* against *opt*."""
|
||||
type_name = opt.type.name
|
||||
if type_name in FFMPEG_NUMERIC_OPTION_TYPES:
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
try:
|
||||
num_val = float(value)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
) from e
|
||||
elif isinstance(value, (float, int)):
|
||||
num_val = value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
)
|
||||
|
||||
# Check integer type compatibility
|
||||
if type_name in FFMPEG_INTEGER_OPTION_TYPES and not num_val.is_integer():
|
||||
raise ValueError(
|
||||
f"{label}={num_val!r} must be an integer for codec {vcodec!r} "
|
||||
f"(FFmpeg option {opt.name!r} is {type_name}); float values are not allowed."
|
||||
)
|
||||
|
||||
# Check numeric range compatibility
|
||||
lo, hi = float(opt.min), float(opt.max)
|
||||
if lo < hi and not (lo <= num_val <= hi):
|
||||
raise ValueError(
|
||||
f"{label}={num_val} is out of range for codec {vcodec!r}; must be in [{lo}, {hi}]"
|
||||
)
|
||||
|
||||
elif type_name == "STRING":
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(f"{label}={value!r} is not a valid string value for codec {vcodec!r}.")
|
||||
if isinstance(value, str):
|
||||
str_val = value
|
||||
elif isinstance(value, (int, float)):
|
||||
str_val = str(value)
|
||||
else:
|
||||
raise ValueError(f"{label}={value!r} has unsupported type for STRING option on codec {vcodec!r}")
|
||||
|
||||
# Check string choice compatibility
|
||||
choices = [c.name for c in (opt.choices or [])]
|
||||
if choices and str_val not in choices:
|
||||
raise ValueError(
|
||||
f"{label}={str_val!r} is not a supported choice for codec "
|
||||
f"{vcodec!r}; valid choices: {choices}"
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
|
||||
formats = _get_codec_video_formats(vcodec)
|
||||
if formats and pix_fmt not in formats:
|
||||
raise ValueError(
|
||||
f"pix_fmt={pix_fmt!r} is not supported by codec {vcodec!r}; "
|
||||
f"supported pixel formats: {list(formats)}"
|
||||
)
|
||||
|
||||
|
||||
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
|
||||
supported_options = _get_codec_options_by_name(vcodec)
|
||||
for key, value in codec_options.items():
|
||||
# GOP size is not a codec-specific option, it has to be validated separately.
|
||||
if key == "g":
|
||||
if isinstance(value, bool) or not isinstance(value, int) or value < 1:
|
||||
raise ValueError(f"g={value!r} must be a positive integer for codec {vcodec!r}")
|
||||
continue
|
||||
if key not in supported_options:
|
||||
continue
|
||||
_check_option_value(vcodec, key, value, supported_options[key])
|
||||
|
||||
|
||||
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
|
||||
"""Verify *config* is compatible with the bundled FFmpeg build.
|
||||
|
||||
Checks pixel format, abstract tuning-field compatibility, and each merged
|
||||
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
|
||||
against PyAV (including numeric ``extra_options`` present in that dict).
|
||||
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
|
||||
|
||||
Raises:
|
||||
ValueError: on the first incompatibility encountered.
|
||||
"""
|
||||
options = _get_codec_options_by_name(vcodec)
|
||||
if not options:
|
||||
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
|
||||
_check_pixel_format(vcodec, pix_fmt)
|
||||
_check_codec_options(vcodec, codec_options)
|
||||
@@ -88,7 +88,6 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
@@ -130,6 +129,9 @@ class DatasetInfo:
|
||||
# Optional metadata
|
||||
robot_type: str | None = None
|
||||
splits: dict[str, str] = field(default_factory=dict)
|
||||
# OpenAI-style tool schemas declared by the dataset. ``None`` means the
|
||||
# dataset doesn't declare any — readers fall back to ``DEFAULT_TOOLS``.
|
||||
tools: list[dict] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Coerce feature shapes from list to tuple — JSON deserialisation
|
||||
@@ -151,11 +153,15 @@ class DatasetInfo:
|
||||
"""Return a JSON-serialisable dict.
|
||||
|
||||
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
||||
Drops ``tools`` when unset so existing datasets keep a clean
|
||||
``info.json``.
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
for ft in d["features"].values():
|
||||
if isinstance(ft.get("shape"), tuple):
|
||||
ft["shape"] = list(ft["shape"])
|
||||
if d.get("tools") is None:
|
||||
d.pop("tools", None)
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -17,12 +17,14 @@ import contextlib
|
||||
import glob
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from collections import OrderedDict
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
@@ -36,86 +38,14 @@ import torch
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
from lerobot.configs import (
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
|
||||
# Determines the order of preference for auto-selection when vcodec="auto" is used.
|
||||
HW_ENCODERS = [
|
||||
"h264_videotoolbox", # macOS
|
||||
"hevc_videotoolbox", # macOS
|
||||
"h264_nvenc", # NVIDIA GPU
|
||||
"hevc_nvenc", # NVIDIA GPU
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
|
||||
|
||||
|
||||
def _get_codec_options(
|
||||
vcodec: str,
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
) -> dict:
|
||||
"""Build codec-specific options dict for video encoding."""
|
||||
options = {}
|
||||
|
||||
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
|
||||
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
|
||||
options["g"] = str(g)
|
||||
|
||||
# Quality control (codec-specific parameter names)
|
||||
if crf is not None:
|
||||
if vcodec in ("h264", "hevc", "libsvtav1"):
|
||||
options["crf"] = str(crf)
|
||||
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
quality = max(1, min(100, int(100 - crf * 2)))
|
||||
options["q:v"] = str(quality)
|
||||
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
options["rc"] = "constqp"
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_vaapi",):
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_qsv",):
|
||||
options["global_quality"] = str(crf)
|
||||
|
||||
# Preset (only for libsvtav1)
|
||||
if vcodec == "libsvtav1":
|
||||
options["preset"] = str(preset) if preset is not None else "12"
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def detect_available_hw_encoders() -> list[str]:
|
||||
"""Probe PyAV/FFmpeg for available hardware video encoders."""
|
||||
available = []
|
||||
for codec_name in HW_ENCODERS:
|
||||
try:
|
||||
av.codec.Codec(codec_name, "w")
|
||||
available.append(codec_name)
|
||||
except Exception: # nosec B110
|
||||
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
|
||||
return available
|
||||
|
||||
|
||||
def resolve_vcodec(vcodec: str) -> str:
|
||||
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if vcodec != "auto":
|
||||
logger.info(f"Using video codec: {vcodec}")
|
||||
return vcodec
|
||||
available = detect_available_hw_encoders()
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
return encoder
|
||||
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
return "libsvtav1"
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: Path | str,
|
||||
@@ -143,7 +73,7 @@ def decode_video_frames(
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = get_safe_default_codec()
|
||||
backend = get_safe_default_video_backend()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend == "pyav":
|
||||
@@ -263,15 +193,70 @@ def decode_video_frames_pyav(
|
||||
return closest_frames
|
||||
|
||||
|
||||
class VideoDecoderCache:
|
||||
"""Thread-safe cache for video decoders to avoid expensive re-initialization."""
|
||||
DEFAULT_DECODER_CACHE_SIZE = 100
|
||||
"""Default LRU capacity for :class:`VideoDecoderCache`.
|
||||
|
||||
def __init__(self):
|
||||
self._cache: dict[str, tuple[Any, Any]] = {}
|
||||
Sized to comfortably hold a small rolling window of episodes worth of decoders
|
||||
(typical recipes: 2-4 cameras per episode × tens of episodes in flight) while
|
||||
bounding host RAM. Each cached entry retains a torchcodec ``VideoDecoder`` plus
|
||||
an open ``fsspec`` file handle — on the order of a few MB per entry. Override
|
||||
via the ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var or by passing ``max_size``
|
||||
to the constructor (``None`` restores the legacy unbounded behaviour).
|
||||
"""
|
||||
|
||||
|
||||
def _default_max_cache_size() -> int | None:
|
||||
raw = os.environ.get("LEROBOT_VIDEO_DECODER_CACHE_SIZE")
|
||||
if raw is None:
|
||||
return DEFAULT_DECODER_CACHE_SIZE
|
||||
raw = raw.strip().lower()
|
||||
if raw in ("", "none", "unbounded", "-1"):
|
||||
return None
|
||||
try:
|
||||
value = int(raw)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be an integer, 'none', or '-1'; got {raw!r}"
|
||||
) from e
|
||||
if value <= 0:
|
||||
raise ValueError(f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be positive; got {value}")
|
||||
return value
|
||||
|
||||
|
||||
class VideoDecoderCache:
|
||||
"""Thread-safe LRU cache for torchcodec ``VideoDecoder`` instances.
|
||||
|
||||
Cached entries hold a ``VideoDecoder`` plus the open ``fsspec`` file handle
|
||||
backing it. When the cache is full and a new path is requested, the
|
||||
least-recently-used entry is evicted and its file handle is closed. This
|
||||
bounds host-RAM growth when iterating over datasets with many distinct
|
||||
video files (otherwise each ``DataLoader`` worker pins every decoder it has
|
||||
ever opened until the process exits).
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of decoders to retain. ``None`` disables
|
||||
eviction and restores legacy unbounded behaviour. Defaults to the
|
||||
value of ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` if set, otherwise
|
||||
:data:`DEFAULT_DECODER_CACHE_SIZE`.
|
||||
"""
|
||||
|
||||
_SENTINEL: ClassVar[object] = object()
|
||||
|
||||
def __init__(self, max_size: int | None | object = _SENTINEL):
|
||||
if max_size is VideoDecoderCache._SENTINEL:
|
||||
max_size = _default_max_cache_size()
|
||||
if max_size is not None and max_size <= 0:
|
||||
raise ValueError(f"max_size must be positive or None; got {max_size}")
|
||||
self.max_size: int | None = max_size # type: ignore[assignment]
|
||||
self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
|
||||
self._lock = Lock()
|
||||
|
||||
def __contains__(self, video_path: object) -> bool:
|
||||
with self._lock:
|
||||
return str(video_path) in self._cache
|
||||
|
||||
def get_decoder(self, video_path: str):
|
||||
"""Get a cached decoder or create a new one."""
|
||||
"""Get a cached decoder or create a new one, evicting LRU if at capacity."""
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
else:
|
||||
@@ -283,22 +268,36 @@ class VideoDecoderCache:
|
||||
video_path = str(video_path)
|
||||
|
||||
with self._lock:
|
||||
if video_path not in self._cache:
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
file_handle.close()
|
||||
raise
|
||||
self._cache[video_path] = (decoder, file_handle)
|
||||
entry = self._cache.get(video_path)
|
||||
if entry is not None:
|
||||
self._cache.move_to_end(video_path)
|
||||
return entry[0]
|
||||
|
||||
return self._cache[video_path][0]
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
file_handle.close()
|
||||
raise
|
||||
self._cache[video_path] = (decoder, file_handle)
|
||||
|
||||
# Evict LRU entries until we are back under the cap. We close
|
||||
# evicted file handles immediately; the associated ``VideoDecoder``
|
||||
# is released to the GC when its last reference goes away.
|
||||
if self.max_size is not None:
|
||||
while len(self._cache) > self.max_size:
|
||||
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
|
||||
with contextlib.suppress(Exception):
|
||||
evicted_handle.close()
|
||||
|
||||
return decoder
|
||||
|
||||
def clear(self):
|
||||
"""Clear the cache and close file handles."""
|
||||
"""Clear the cache and close all file handles."""
|
||||
with self._lock:
|
||||
for _, file_handle in self._cache.values():
|
||||
file_handle.close()
|
||||
with contextlib.suppress(Exception):
|
||||
file_handle.close()
|
||||
self._cache.clear()
|
||||
|
||||
def size(self) -> int:
|
||||
@@ -407,18 +406,17 @@ def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
preset: int | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -429,42 +427,18 @@ def encode_video_frames(
|
||||
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encoders/pixel formats incompatibility check
|
||||
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
|
||||
logger.warning(
|
||||
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
|
||||
)
|
||||
pix_fmt = "yuv420p"
|
||||
|
||||
# Get input frames
|
||||
template = "frame-" + ("[0-9]" * 6) + ".png"
|
||||
input_list = sorted(
|
||||
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
|
||||
)
|
||||
|
||||
# Define video output frame size (assuming all input frames are the same size)
|
||||
if len(input_list) == 0:
|
||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
|
||||
# Define video codec options
|
||||
video_options = _get_codec_options(vcodec, g, crf, preset)
|
||||
|
||||
if fast_decode:
|
||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
video_options[key] = value
|
||||
|
||||
if encoder_threads is not None:
|
||||
if vcodec == "libsvtav1":
|
||||
lp_param = f"lp={encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(encoder_threads)
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -500,8 +474,97 @@ def encode_video_frames(
|
||||
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
||||
|
||||
|
||||
def reencode_video(
|
||||
input_video_path: Path | str,
|
||||
output_video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""Re-encode a video file using the given encoder configuration.
|
||||
|
||||
Args:
|
||||
input_video_path: Existing video file to read.
|
||||
output_video_path: Path for the re-encoded file.
|
||||
camera_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
|
||||
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
|
||||
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
|
||||
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
|
||||
"""
|
||||
|
||||
camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
|
||||
output_video_path = Path(output_video_path)
|
||||
|
||||
if output_video_path.exists() and not overwrite:
|
||||
logger.warning(f"Video file already exists: {output_video_path}. Skipping re-encode.")
|
||||
return
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
|
||||
if log_level is not None:
|
||||
logging.getLogger("libav").setLevel(log_level)
|
||||
|
||||
try:
|
||||
with av.open(input_video_path, mode="r") as src:
|
||||
try:
|
||||
in_stream = src.streams.video[0]
|
||||
except IndexError as e:
|
||||
raise ValueError(f"No video stream in {input_video_path}") from e
|
||||
|
||||
fps = (
|
||||
in_stream.base_rate
|
||||
) # We allow fractional fps though LeRobotDataset only supports integer fps
|
||||
width = int(in_stream.width)
|
||||
height = int(in_stream.height)
|
||||
|
||||
with av.open(
|
||||
tmp_output_video_path,
|
||||
mode="w",
|
||||
options={
|
||||
"movflags": "faststart"
|
||||
}, # faststart is to move the metadata to the beginning of the file to speed up loading
|
||||
) as dst:
|
||||
out_stream = dst.add_stream(vcodec, fps, options=video_options)
|
||||
out_stream.pix_fmt = pix_fmt
|
||||
out_stream.width = width
|
||||
out_stream.height = height
|
||||
|
||||
for frame in src.decode(in_stream):
|
||||
frame = frame.reformat(width=width, height=height, format=pix_fmt)
|
||||
packet = out_stream.encode(frame)
|
||||
if packet:
|
||||
dst.mux(packet)
|
||||
|
||||
packet = out_stream.encode()
|
||||
if packet:
|
||||
dst.mux(packet)
|
||||
|
||||
shutil.move(tmp_output_video_path, output_video_path)
|
||||
except Exception:
|
||||
Path(tmp_output_video_path).unlink(missing_ok=True)
|
||||
raise
|
||||
finally:
|
||||
if log_level is not None:
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
if not output_video_path.exists():
|
||||
raise OSError(f"Video re-encoding did not work. File not found: {output_video_path}.")
|
||||
|
||||
|
||||
def concatenate_video_files(
|
||||
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
|
||||
input_video_paths: list[Path | str],
|
||||
output_video_path: Path,
|
||||
overwrite: bool = True,
|
||||
compatibility_check: bool = False,
|
||||
):
|
||||
"""
|
||||
Concatenate multiple video files into a single video file using pyav.
|
||||
@@ -514,6 +577,7 @@ def concatenate_video_files(
|
||||
input_video_paths: Ordered list of input video file paths to concatenate.
|
||||
output_video_path: Path to the output video file.
|
||||
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
|
||||
compatibility_check: Whether to check if the input videos are compatible. Default is False.
|
||||
|
||||
Note:
|
||||
- Creates a temporary directory for intermediate files that is cleaned up after use.
|
||||
@@ -532,6 +596,22 @@ def concatenate_video_files(
|
||||
if len(input_video_paths) == 0:
|
||||
raise FileNotFoundError("No input video paths provided.")
|
||||
|
||||
# This check may be skipped at recording time as videos are encoded with the same encoder config.
|
||||
if compatibility_check:
|
||||
reference_video_info = get_video_info(input_video_paths[0])
|
||||
for input_path in input_video_paths[1:]:
|
||||
video_info = get_video_info(input_path)
|
||||
if (
|
||||
video_info["video.height"] != reference_video_info["video.height"]
|
||||
or video_info["video.width"] != reference_video_info["video.width"]
|
||||
or video_info["video.fps"] != reference_video_info["video.fps"]
|
||||
or video_info["video.codec"] != reference_video_info["video.codec"]
|
||||
or video_info["video.pix_fmt"] != reference_video_info["video.pix_fmt"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Input video {input_path} is not compatible with the reference video {input_video_paths[0]}."
|
||||
)
|
||||
|
||||
# Create a temporary .ffconcat file to list the input video paths
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
||||
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
||||
@@ -598,26 +678,20 @@ class _CameraEncoderThread(threading.Thread):
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int | None,
|
||||
crf: int | None,
|
||||
preset: int | None,
|
||||
codec_options: dict[str, str],
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.codec_options = codec_options
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
@@ -653,19 +727,9 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Open container on first frame (to get width/height)
|
||||
if container is None:
|
||||
height, width = frame_data.shape[:2]
|
||||
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
|
||||
if self.encoder_threads is not None:
|
||||
if self.vcodec == "libsvtav1":
|
||||
lp_param = f"lp={self.encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(self.encoder_threads)
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
@@ -731,22 +795,24 @@ class StreamingVideoEncoder:
|
||||
def __init__(
|
||||
self,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
fps: Frames per second for the output videos.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
When ``None``, :func:`camera_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global setting).
|
||||
``None`` lets the codec decide.
|
||||
queue_maxsize: Max frames to buffer per camera before
|
||||
back-pressure drops frames.
|
||||
"""
|
||||
self.fps = fps
|
||||
self.vcodec = resolve_vcodec(vcodec)
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
self._frame_queues: dict[str, queue.Queue] = {}
|
||||
self._result_queues: dict[str, queue.Queue] = {}
|
||||
@@ -777,18 +843,17 @@ class StreamingVideoEncoder:
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
vcodec = self._camera_encoder.vcodec
|
||||
codec_options = self._camera_encoder.get_codec_options(self._encoder_threads, as_strings=True)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=self.vcodec,
|
||||
pix_fmt=self.pix_fmt,
|
||||
g=self.g,
|
||||
crf=self.crf,
|
||||
preset=self.preset,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=self._camera_encoder.pix_fmt,
|
||||
codec_options=codec_options,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
encoder_threads=self.encoder_threads,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
@@ -993,8 +1058,18 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
return audio_info
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
camera_encoder: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself.
|
||||
"""
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
# Getting video stream information
|
||||
@@ -1025,6 +1100,14 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
# Adding audio stream information
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided
|
||||
if camera_encoder is not None:
|
||||
for field_name, field_value in asdict(camera_encoder).items():
|
||||
# vcodec is already populated from the video stream
|
||||
if field_name == "vcodec":
|
||||
continue
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
|
||||
@@ -18,12 +18,25 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.import_utils import _placo_available, require_package
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
if TYPE_CHECKING or _placo_available:
|
||||
_placo_runtime_error: ImportError | None = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import placo # type: ignore[import-not-found]
|
||||
else:
|
||||
placo = None
|
||||
try:
|
||||
import placo # type: ignore[import-not-found]
|
||||
except ImportError as _placo_import_err:
|
||||
placo = None
|
||||
_placo_runtime_error = _placo_import_err
|
||||
|
||||
|
||||
def _raise_if_placo_unusable() -> None:
|
||||
if placo is None and _placo_runtime_error is not None:
|
||||
raise ImportError(
|
||||
f"placo is installed but failed to import: {_placo_runtime_error!s}"
|
||||
) from _placo_runtime_error
|
||||
|
||||
|
||||
class RobotKinematics:
|
||||
@@ -44,6 +57,7 @@ class RobotKinematics:
|
||||
joint_names (list[str] | None): List of joint names to use for the kinematics solver
|
||||
"""
|
||||
require_package("placo", extra="placo-dep")
|
||||
_raise_if_placo_unusable()
|
||||
|
||||
self.robot = placo.RobotWrapper(urdf_path)
|
||||
self.solver = placo.KinematicsSolver(self.robot)
|
||||
|
||||
@@ -43,6 +43,7 @@ from .tables import (
|
||||
CAN_CMD_SET_ZERO,
|
||||
DEFAULT_BAUDRATE,
|
||||
DEFAULT_TIMEOUT_MS,
|
||||
HANDSHAKE_TIMEOUT_S,
|
||||
MODEL_RESOLUTION,
|
||||
MOTOR_LIMIT_PARAMS,
|
||||
NORMALIZED_DATA,
|
||||
@@ -215,14 +216,16 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
self._is_connected = False
|
||||
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
|
||||
|
||||
def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
|
||||
def _query_status_via_clear_fault(
|
||||
self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
|
||||
) -> tuple[bool, can.Message | None]:
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_id = self._get_motor_id(motor_name)
|
||||
recv_id = self._get_motor_recv_id(motor_name)
|
||||
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self._bus().send(msg)
|
||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
|
||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)
|
||||
|
||||
def _recv_status_via_clear_fault(
|
||||
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
|
||||
@@ -280,7 +283,7 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
faulted_motors = []
|
||||
|
||||
for motor_name in self.motors:
|
||||
has_fault, msg = self._query_status_via_clear_fault(motor_name)
|
||||
has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
|
||||
if msg is None:
|
||||
missing_motors.append(motor_name)
|
||||
elif has_fault:
|
||||
@@ -505,6 +508,87 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
|
||||
return responses
|
||||
|
||||
def _recv_all_messages_until_quiet(
|
||||
self,
|
||||
*,
|
||||
timeout: float = RUNNING_TIMEOUT,
|
||||
max_messages: int = 4096,
|
||||
) -> list[can.Message]:
|
||||
"""
|
||||
Receive frames until the bus goes quiet.
|
||||
|
||||
Args:
|
||||
timeout: Poll timeout used for each recv() call. Collection stops
|
||||
when one recv() times out (quiet gap).
|
||||
max_messages: Safety cap to prevent unbounded loops.
|
||||
"""
|
||||
out: list[can.Message] = []
|
||||
max_messages = max(1, max_messages)
|
||||
timeout = max(0.0, timeout)
|
||||
|
||||
try:
|
||||
while len(out) < max_messages:
|
||||
msg = self._bus().recv(timeout=timeout)
|
||||
if msg is None:
|
||||
break
|
||||
out.append(msg)
|
||||
except (can.CanError, OSError) as e:
|
||||
logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
|
||||
|
||||
return out
|
||||
|
||||
def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
|
||||
"""
|
||||
Decode all received feedback frames and update cached motor states.
|
||||
|
||||
Returns:
|
||||
Set of payload recv_ids that were successfully mapped to motors.
|
||||
"""
|
||||
processed_recv_ids: set[int] = set()
|
||||
for msg in messages:
|
||||
if len(msg.data) < 1:
|
||||
logger.debug(
|
||||
f"Dropping short CAN frame on {self.port} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
|
||||
)
|
||||
continue
|
||||
|
||||
recv_id = int(msg.data[0])
|
||||
motor_name = self._recv_id_to_motor.get(recv_id)
|
||||
if motor_name is None:
|
||||
logger.debug(
|
||||
f"Unmapped CAN frame on {self.port} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
|
||||
)
|
||||
continue
|
||||
|
||||
self._process_response(motor_name, msg)
|
||||
processed_recv_ids.add(recv_id)
|
||||
|
||||
return processed_recv_ids
|
||||
|
||||
def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
|
||||
"""
|
||||
Drain pending RX frames from the CAN interface.
|
||||
|
||||
This is used by higher-level controllers to drop stale feedback before issuing
|
||||
a fresh read cycle, so subsequent state reads are based on most recent replies.
|
||||
It should also be called once when a controller instance is created/connected,
|
||||
to clear residual frames left on the interface from previous sessions.
|
||||
"""
|
||||
drained = 0
|
||||
poll_timeout_s = max(0.0, poll_timeout_s)
|
||||
max_messages = max(1, max_messages)
|
||||
try:
|
||||
while drained < max_messages:
|
||||
msg = self._bus().recv(timeout=poll_timeout_s)
|
||||
if msg is None:
|
||||
break
|
||||
drained += 1
|
||||
except (can.CanError, OSError) as e:
|
||||
logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
|
||||
return drained
|
||||
|
||||
def _speed_control(
|
||||
self,
|
||||
motor: NameOrID,
|
||||
@@ -644,11 +728,14 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self._bus().send(msg)
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
# Read every feedback frame until RX goes quiet, then decode all of them.
|
||||
# This avoids dropping useful frames when responses from different motors interleave.
|
||||
messages = self._recv_all_messages_until_quiet()
|
||||
processed_recv_ids = self._process_feedback_messages(messages)
|
||||
|
||||
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
|
||||
for recv_id, motor_name in recv_id_to_motor.items():
|
||||
if msg := responses.get(recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
if recv_id not in processed_recv_ids:
|
||||
logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||
"""Convert float to unsigned integer for CAN transmission."""
|
||||
@@ -711,7 +798,10 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
try:
|
||||
self._decode_motor_state(msg.data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode response from {motor}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to decode response from {motor} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
|
||||
)
|
||||
|
||||
def _get_cached_value(self, motor: str, data_name: str) -> Value:
|
||||
"""Retrieve a specific value from the state cache."""
|
||||
@@ -848,20 +938,12 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
self._bus().send(msg)
|
||||
updated_motors.append(motor)
|
||||
|
||||
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
|
||||
responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
|
||||
|
||||
for response in responses.values():
|
||||
payload_motor_name = self._recv_id_to_motor.get(response.data[0])
|
||||
if payload_motor_name is not None:
|
||||
self._process_response(payload_motor_name, response)
|
||||
else:
|
||||
# Fallback: still attempt to decode based on payload byte0 mapping.
|
||||
self._decode_motor_state(response.data)
|
||||
messages = self._recv_all_messages_until_quiet()
|
||||
processed_recv_ids = self._process_feedback_messages(messages)
|
||||
|
||||
for motor in updated_motors:
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
if recv_id not in responses:
|
||||
if recv_id not in processed_recv_ids:
|
||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
|
||||
@@ -114,7 +114,8 @@ CAN_CMD_SAVE_PARAM = 0xAA
|
||||
CAN_PARAM_ID = 0x7FF
|
||||
|
||||
|
||||
RUNNING_TIMEOUT = 0.001
|
||||
RUNNING_TIMEOUT = 0.003
|
||||
HANDSHAKE_TIMEOUT_S = 0.05
|
||||
PARAM_TIMEOUT = 0.01
|
||||
|
||||
STATE_CACHE_TTL_S = 0.02
|
||||
|
||||
@@ -20,6 +20,7 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
@@ -43,6 +44,7 @@ __all__ = [
|
||||
"EO1Config",
|
||||
"GaussianActorConfig",
|
||||
"GrootConfig",
|
||||
"MolmoAct2Config",
|
||||
"MultiTaskDiTConfig",
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
|
||||
@@ -28,11 +28,12 @@ import torch.nn.functional as F # noqa: N812
|
||||
import torch.utils.checkpoint
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from .configuration_eo1 import EO1Config
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
@@ -22,7 +22,6 @@ from typing import TYPE_CHECKING, Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
@@ -44,6 +43,8 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .configuration_eo1 import EO1Config
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
else:
|
||||
|
||||
@@ -49,6 +49,7 @@ from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
from .pi05.configuration_pi05 import PI05Config
|
||||
@@ -56,6 +57,7 @@ from .pretrained import PreTrainedPolicy
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig
|
||||
@@ -88,7 +90,8 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x",
|
||||
"molmoact2".
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -151,6 +154,14 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .eo1.modeling_eo1 import EO1Policy
|
||||
|
||||
return EO1Policy
|
||||
elif name == "molmoact2":
|
||||
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
|
||||
|
||||
return MolmoAct2Policy
|
||||
elif name == "vla_jepa":
|
||||
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
|
||||
return VLAJEPAPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -168,7 +179,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "wall_x".
|
||||
"smolvla", "wall_x", "molmoact2".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -203,6 +214,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return WallXConfig(**kwargs)
|
||||
elif policy_type == "eo1":
|
||||
return EO1Config(**kwargs)
|
||||
elif policy_type == "molmoact2":
|
||||
return MolmoAct2Config(**kwargs)
|
||||
elif policy_type == "vla_jepa":
|
||||
return VLAJEPAConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -231,6 +246,7 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
dataset_meta: Any | None
|
||||
|
||||
|
||||
def make_pre_post_processors(
|
||||
@@ -406,6 +422,7 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
@@ -414,6 +431,23 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, MolmoAct2Config):
|
||||
from .molmoact2.processor_molmoact2 import make_molmoact2_pre_post_processors
|
||||
|
||||
processors = make_molmoact2_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, VLAJEPAConfig):
|
||||
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
processors = make_vla_jepa_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
@@ -499,6 +533,10 @@ def make_policy(
|
||||
action_names = ds_meta.features.get(ACTION, {}).get("names")
|
||||
if action_names is not None:
|
||||
cfg.action_feature_names = list(action_names)
|
||||
if ds_meta is not None:
|
||||
set_dataset_feature_metadata = getattr(cfg, "set_dataset_feature_metadata", None)
|
||||
if callable(set_dataset_feature_metadata):
|
||||
set_dataset_feature_metadata(ds_meta.features)
|
||||
|
||||
kwargs["config"] = cfg
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ class Eagle25VLPreTrainedModel(PreTrainedModel):
|
||||
"SiglipEncoderLayer",
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
@@ -124,7 +124,6 @@ class Eagle25VLProcessor(ProcessorMixin):
|
||||
"videos_kwargs",
|
||||
"text_kwargs",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -26,9 +26,14 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from huggingface_hub.dataclasses import strict
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
|
||||
def strict(cls):
|
||||
return cls
|
||||
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
@@ -173,19 +178,20 @@ N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@strict
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
backbone_cfg: dict
|
||||
action_head_cfg: dict
|
||||
action_horizon: int
|
||||
action_dim: int
|
||||
backbone_cfg: dict[str, Any] | None = None
|
||||
action_head_cfg: dict[str, Any] | None = None
|
||||
action_horizon: int = 0
|
||||
action_dim: int = 0
|
||||
compute_dtype: str = "float32"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
def __post_init__(self, **kwargs):
|
||||
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
|
||||
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
|
||||
# real model
|
||||
|
||||
@@ -206,7 +206,11 @@ def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS
|
||||
"Vendor files are copied during model creation. Create the policy/model first, "
|
||||
"or call ensure_eagle_cache_ready() before building processors."
|
||||
)
|
||||
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
|
||||
proc = AutoProcessor.from_pretrained(
|
||||
str(cache_dir),
|
||||
trust_remote_code=True,
|
||||
fix_mistral_regex=False,
|
||||
)
|
||||
proc.tokenizer.padding_side = "left"
|
||||
return proc
|
||||
|
||||
|
||||
1
src/lerobot/policies/molmoact2/README.md
Symbolic link
1
src/lerobot/policies/molmoact2/README.md
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_molmoact2_README.md
|
||||
21
src/lerobot/policies/molmoact2/__init__.py
Normal file
21
src/lerobot/policies/molmoact2/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_molmoact2 import MolmoAct2Config
|
||||
from .modeling_molmoact2 import MolmoAct2Policy
|
||||
from .processor_molmoact2 import make_molmoact2_pre_post_processors
|
||||
|
||||
__all__ = ["MolmoAct2Config", "MolmoAct2Policy", "make_molmoact2_pre_post_processors"]
|
||||
519
src/lerobot/policies/molmoact2/configuration_molmoact2.py
Normal file
519
src/lerobot/policies/molmoact2/configuration_molmoact2.py
Normal file
@@ -0,0 +1,519 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.optim import (
|
||||
AdamWConfig,
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
LRSchedulerConfig,
|
||||
OptimizerConfig,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
from ..rtc.configuration_rtc import RTCConfig
|
||||
|
||||
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
|
||||
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
|
||||
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
|
||||
MOLMOACT2_TASK_TOKEN_BUDGET = 32
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
|
||||
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
|
||||
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
|
||||
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
|
||||
|
||||
|
||||
def _hf_token() -> str | None:
|
||||
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def _resolve_checkpoint_location(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> str:
|
||||
checkpoint_path = str(checkpoint_path or "").strip()
|
||||
if not checkpoint_path:
|
||||
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
|
||||
local_path = Path(checkpoint_path).expanduser()
|
||||
if local_path.exists():
|
||||
return str(local_path)
|
||||
return snapshot_download(
|
||||
repo_id=checkpoint_path,
|
||||
repo_type="model",
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
|
||||
token=_hf_token(),
|
||||
)
|
||||
|
||||
|
||||
def _load_hf_norm_metadata_for_tag(
|
||||
checkpoint_path: str,
|
||||
*,
|
||||
revision: str | None,
|
||||
force_download: bool,
|
||||
norm_tag: str | None,
|
||||
) -> dict[str, Any]:
|
||||
norm_tag = str(norm_tag or "").strip()
|
||||
if not norm_tag:
|
||||
return {}
|
||||
checkpoint_location = Path(
|
||||
_resolve_checkpoint_location(
|
||||
checkpoint_path,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
)
|
||||
norm_stats_filename = "norm_stats.json"
|
||||
config_path = checkpoint_location / "config.json"
|
||||
if config_path.exists():
|
||||
with suppress(OSError, json.JSONDecodeError):
|
||||
norm_stats_filename = str(
|
||||
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
|
||||
)
|
||||
stats_path = checkpoint_location / norm_stats_filename
|
||||
if not stats_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
|
||||
)
|
||||
payload = json.loads(stats_path.read_text())
|
||||
metadata_by_tag = payload.get("metadata_by_tag")
|
||||
if not isinstance(metadata_by_tag, dict):
|
||||
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
|
||||
metadata = metadata_by_tag.get(norm_tag)
|
||||
if not isinstance(metadata, dict):
|
||||
available = sorted(str(tag) for tag in metadata_by_tag)
|
||||
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
|
||||
return metadata
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig):
|
||||
"""MolmoAct2-local cosine scheduler with optional decay-step auto-match.
|
||||
|
||||
LeRobot's generic cosine scheduler keeps an explicit integer decay length.
|
||||
For MolmoAct2, leaving num_decay_steps unset means "decay across this run's
|
||||
training steps"; build() is the first point where num_training_steps is known.
|
||||
"""
|
||||
|
||||
num_decay_steps: int | None
|
||||
|
||||
def build(self, optimizer, num_training_steps: int):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.peak_lr,
|
||||
decay_lr=self.decay_lr,
|
||||
num_warmup_steps=self.num_warmup_steps,
|
||||
num_decay_steps=num_training_steps if self.num_decay_steps is None else self.num_decay_steps,
|
||||
).build(optimizer, num_training_steps=num_training_steps)
|
||||
|
||||
|
||||
def _round_up(value: int, multiple: int) -> int:
|
||||
return int(math.ceil(value / multiple) * multiple)
|
||||
|
||||
|
||||
def infer_molmoact2_max_sequence_length(
|
||||
*,
|
||||
num_images: int,
|
||||
state_dim: int,
|
||||
action_dim: int,
|
||||
action_horizon: int,
|
||||
include_discrete_action: bool,
|
||||
) -> int:
|
||||
"""Infer the padded text/image sequence cap from MolmoAct2's fixed token layout."""
|
||||
if num_images < 1:
|
||||
num_images = MOLMOACT2_DEFAULT_NUM_IMAGES
|
||||
if state_dim < 0:
|
||||
state_dim = 0
|
||||
if action_dim < 1:
|
||||
action_dim = 1
|
||||
if action_horizon < 1:
|
||||
action_horizon = 1
|
||||
|
||||
image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
|
||||
prompt_tokens = (
|
||||
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
|
||||
+ MOLMOACT2_TASK_TOKEN_BUDGET
|
||||
+ state_dim
|
||||
+ MOLMOACT2_SEQUENCE_LENGTH_MARGIN
|
||||
)
|
||||
action_tokens = 0
|
||||
if include_discrete_action:
|
||||
action_tokens_per_step = max(
|
||||
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
|
||||
math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM),
|
||||
)
|
||||
action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step
|
||||
|
||||
return _round_up(
|
||||
image_tokens + prompt_tokens + action_tokens,
|
||||
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE,
|
||||
)
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("molmoact2")
|
||||
@dataclass
|
||||
class MolmoAct2Config(PreTrainedConfig):
|
||||
"""MolmoAct2 policy backed by the converted HF checkpoint implementation."""
|
||||
|
||||
checkpoint_path: str = "allenai/MolmoAct2"
|
||||
checkpoint_revision: str | None = None
|
||||
checkpoint_force_download: bool = False
|
||||
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 30
|
||||
n_action_steps: int = 30
|
||||
|
||||
action_mode: str = "both"
|
||||
inference_action_mode: str | None = None
|
||||
discrete_action_tokenizer: str = "allenai/MolmoAct2-FAST-Tokenizer"
|
||||
discrete_generation_max_steps: int | None = None
|
||||
norm_tag: str | None = None
|
||||
|
||||
setup_type: str = ""
|
||||
control_mode: str = ""
|
||||
image_keys: list[str] = field(default_factory=list)
|
||||
normalize_language: bool = True
|
||||
add_setup_tokens: bool = True
|
||||
add_control_tokens: bool = True
|
||||
normalize_gripper: bool = False
|
||||
num_state_tokens: int = 256
|
||||
# Leave unset for the default MolmoAct2 sequence budget inferred from the fixed
|
||||
# image/prompt/state/action token layout. Override only for unusual long prompts.
|
||||
max_sequence_length: int | None = None
|
||||
|
||||
# Fixed by released MolmoAct2 checkpoints. We validate this at model load.
|
||||
expected_max_action_dim: int = 32
|
||||
|
||||
# Flow-matching training knobs copied from the original MolmoAct2 training path.
|
||||
num_flow_timesteps: int = 8
|
||||
flow_matching_cutoff: float = 1.0
|
||||
flow_matching_time_offset: float = 0.001
|
||||
flow_matching_time_scale: float = 0.999
|
||||
flow_matching_beta_alpha: float = 1.0
|
||||
flow_matching_beta_beta: float = 1.5
|
||||
num_inference_steps: int | None = None
|
||||
mask_action_dim_padding: bool = True
|
||||
enable_inference_cuda_graph: bool = True
|
||||
# MolmoAct2-local eval option. When enabled, stochastic continuous action
|
||||
# generation uses a rollout-local generator derived from eval_seed.
|
||||
per_episode_seed: bool = False
|
||||
eval_seed: int | None = None
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
# Default is full finetuning with gradients from the action expert flowing into the VLM.
|
||||
enable_lora_vlm: bool = False
|
||||
lora_rank: int = 64
|
||||
lora_alpha: int = 16
|
||||
lora_dropout: float = 0.05
|
||||
lora_bias: str = "none"
|
||||
enable_lora_action_expert: bool = False
|
||||
enable_knowledge_insulation: bool = False
|
||||
freeze_embedding: bool = True
|
||||
train_action_expert_only: bool = False
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
model_dtype: str = "bfloat16"
|
||||
softmax_auxiliary_loss: bool = True
|
||||
softmax_auxiliary_loss_scale: float = 1e-4
|
||||
discrete_loss_token_weighting: str = "root_subsegments_root_tokens"
|
||||
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_vit_lr: float = 5e-6
|
||||
optimizer_connector_lr: float = 5e-6
|
||||
optimizer_action_expert_lr: float = 5e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-6
|
||||
optimizer_weight_decay: float = 0.0
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
scheduler_warmup_steps: int = 200
|
||||
scheduler_decay_steps: int | None = None
|
||||
scheduler_decay_lr: float = 1e-6
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.QUANTILES,
|
||||
"ACTION": NormalizationMode.QUANTILES,
|
||||
}
|
||||
)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
dataset_feature_names: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.action_mode not in {"continuous", "discrete", "both"}:
|
||||
raise ValueError(
|
||||
f"Unsupported action_mode={self.action_mode!r}. "
|
||||
"Expected one of {'continuous', 'discrete', 'both'}."
|
||||
)
|
||||
if self.inference_action_mode not in {None, "continuous", "discrete"}:
|
||||
raise ValueError(
|
||||
f"Unsupported inference_action_mode={self.inference_action_mode!r}. "
|
||||
"Expected one of {None, 'continuous', 'discrete'}."
|
||||
)
|
||||
if self.inference_action_mode == "continuous" and self.action_mode == "discrete":
|
||||
raise ValueError("MolmoAct2 action_mode='discrete' cannot run continuous inference.")
|
||||
if self.inference_action_mode == "discrete" and self.action_mode == "continuous":
|
||||
raise ValueError("MolmoAct2 action_mode='continuous' cannot run discrete inference.")
|
||||
if self.train_action_expert_only and self.action_mode != "continuous":
|
||||
raise ValueError("MolmoAct2 train_action_expert_only requires action_mode='continuous'.")
|
||||
if self.train_action_expert_only and self.enable_lora_vlm:
|
||||
raise ValueError("MolmoAct2 train_action_expert_only is incompatible with enable_lora_vlm.")
|
||||
if self.enable_lora_action_expert and not self.enable_lora_vlm:
|
||||
raise ValueError("MolmoAct2 enable_lora_action_expert requires enable_lora_vlm.")
|
||||
if self.chunk_size < 1:
|
||||
raise ValueError(f"chunk_size must be >= 1, got {self.chunk_size}.")
|
||||
if self.n_action_steps < 1:
|
||||
raise ValueError(f"n_action_steps must be >= 1, got {self.n_action_steps}.")
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})."
|
||||
)
|
||||
if self.expected_max_action_dim != 32:
|
||||
raise ValueError("MolmoAct2 released checkpoints use expected_max_action_dim=32.")
|
||||
if self.model_dtype not in {"float32", "bfloat16", "float16"}:
|
||||
raise ValueError(
|
||||
f"Unsupported model_dtype={self.model_dtype!r}. Expected 'float32', 'bfloat16', or 'float16'."
|
||||
)
|
||||
if self.lora_rank < 1:
|
||||
raise ValueError(f"lora_rank must be >= 1, got {self.lora_rank}.")
|
||||
if self.lora_alpha < 1:
|
||||
raise ValueError(f"lora_alpha must be >= 1, got {self.lora_alpha}.")
|
||||
if not 0 <= self.lora_dropout <= 1:
|
||||
raise ValueError(f"lora_dropout must be in [0, 1], got {self.lora_dropout}.")
|
||||
if self.lora_bias not in {"none", "all", "lora_only"}:
|
||||
raise ValueError(
|
||||
f"Unsupported lora_bias={self.lora_bias!r}. Expected one of 'none', 'all', or 'lora_only'."
|
||||
)
|
||||
if self.discrete_loss_token_weighting not in {
|
||||
"none",
|
||||
"token",
|
||||
"root_tokens",
|
||||
"root_subsegments",
|
||||
"root_subsegments_root_tokens",
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Unsupported discrete_loss_token_weighting={self.discrete_loss_token_weighting!r}."
|
||||
)
|
||||
if self.discrete_generation_max_steps is not None and self.discrete_generation_max_steps < 1:
|
||||
raise ValueError(
|
||||
f"discrete_generation_max_steps must be >= 1 or None, got {self.discrete_generation_max_steps}."
|
||||
)
|
||||
if self.max_sequence_length is not None and self.max_sequence_length < 1:
|
||||
raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.")
|
||||
|
||||
def inferred_max_sequence_length(
|
||||
self,
|
||||
*,
|
||||
num_images: int | None = None,
|
||||
state_dim: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
action_horizon: int | None = None,
|
||||
include_discrete_action: bool | None = None,
|
||||
) -> int:
|
||||
if self.max_sequence_length is not None:
|
||||
return int(self.max_sequence_length)
|
||||
|
||||
if num_images is None:
|
||||
num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES
|
||||
if state_dim is None:
|
||||
state_feature = self.robot_state_feature
|
||||
state_dim = int(state_feature.shape[0]) if state_feature is not None else 0
|
||||
if action_dim is None:
|
||||
action_feature = self.action_feature
|
||||
action_dim = (
|
||||
int(action_feature.shape[0]) if action_feature is not None else self.expected_max_action_dim
|
||||
)
|
||||
if action_horizon is None:
|
||||
action_horizon = self.chunk_size
|
||||
if include_discrete_action is None:
|
||||
include_discrete_action = self.action_mode in {"discrete", "both"}
|
||||
|
||||
return infer_molmoact2_max_sequence_length(
|
||||
num_images=int(num_images),
|
||||
state_dim=int(state_dim),
|
||||
action_dim=int(action_dim),
|
||||
action_horizon=int(action_horizon),
|
||||
include_discrete_action=bool(include_discrete_action),
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return MolmoAct2CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
def set_dataset_feature_metadata(self, features: dict[str, Any]) -> None:
|
||||
self.dataset_feature_names = {}
|
||||
for key in (ACTION, OBS_STATE):
|
||||
feature = features.get(key) if isinstance(features, dict) else None
|
||||
if isinstance(feature, dict) and feature.get("names") is not None:
|
||||
self.dataset_feature_names[key] = feature["names"]
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up MolmoAct2 input and output features."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"MolmoAct2 policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(0,),
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.expected_max_action_dim,),
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def apply_norm_tag_metadata(self) -> None:
|
||||
if not str(self.norm_tag or "").strip():
|
||||
return
|
||||
metadata = _load_hf_norm_metadata_for_tag(
|
||||
self.checkpoint_path,
|
||||
revision=self.checkpoint_revision,
|
||||
force_download=bool(self.checkpoint_force_download),
|
||||
norm_tag=self.norm_tag,
|
||||
)
|
||||
if metadata.get("action_horizon") is not None:
|
||||
self.chunk_size = int(metadata["action_horizon"])
|
||||
if metadata.get("n_action_steps") is not None:
|
||||
self.n_action_steps = int(metadata["n_action_steps"])
|
||||
if not self.setup_type and metadata.get("setup_type") is not None:
|
||||
self.setup_type = str(metadata["setup_type"])
|
||||
if not self.control_mode and metadata.get("control_mode") is not None:
|
||||
self.control_mode = str(metadata["control_mode"])
|
||||
|
||||
def saved_policy_action_mode(self) -> str | None:
|
||||
pretrained_path = getattr(self, "pretrained_path", None)
|
||||
if pretrained_path is None:
|
||||
return None
|
||||
config_path = Path(pretrained_path) / "config.json"
|
||||
if not config_path.exists():
|
||||
return None
|
||||
try:
|
||||
mode = json.loads(config_path.read_text()).get("action_mode")
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
if mode in {"continuous", "discrete", "both"}:
|
||||
return str(mode)
|
||||
return None
|
||||
|
||||
def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str:
|
||||
return saved_policy_action_mode or self.action_mode
|
||||
|
||||
def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None:
|
||||
requested_mode = self.inference_action_mode
|
||||
if requested_mode is None:
|
||||
return
|
||||
training_mode = self.training_action_mode(saved_policy_action_mode)
|
||||
if requested_mode == "continuous" and training_mode == "discrete":
|
||||
raise ValueError(
|
||||
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
|
||||
"continuous inference."
|
||||
)
|
||||
if requested_mode == "discrete" and training_mode == "continuous":
|
||||
raise ValueError(
|
||||
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
|
||||
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
|
||||
)
|
||||
|
||||
def validate_checkpoint_action_mode(
|
||||
self,
|
||||
checkpoint_action_mode: str,
|
||||
*,
|
||||
has_action_expert: bool,
|
||||
) -> None:
|
||||
if self.action_mode == "both" and checkpoint_action_mode != "both":
|
||||
raise ValueError(
|
||||
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
|
||||
)
|
||||
if self.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
|
||||
raise ValueError(
|
||||
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
|
||||
f"got {checkpoint_action_mode!r}."
|
||||
)
|
||||
if self.action_mode in {"continuous", "both"} and not has_action_expert:
|
||||
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
|
||||
|
||||
def resolve_inference_action_mode(
|
||||
self,
|
||||
requested_mode: str | None,
|
||||
saved_policy_action_mode: str | None = None,
|
||||
) -> str:
|
||||
training_mode = self.training_action_mode(saved_policy_action_mode)
|
||||
if requested_mode is None:
|
||||
requested_mode = self.inference_action_mode
|
||||
if requested_mode is None:
|
||||
raise ValueError(
|
||||
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
|
||||
"to either 'continuous' or 'discrete'."
|
||||
)
|
||||
if requested_mode not in {"continuous", "discrete"}:
|
||||
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
|
||||
if requested_mode == "continuous" and training_mode == "discrete":
|
||||
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
|
||||
if requested_mode == "discrete" and training_mode == "continuous":
|
||||
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
|
||||
return requested_mode
|
||||
17
src/lerobot/policies/molmoact2/hf_model/__init__.py
Normal file
17
src/lerobot/policies/molmoact2/hf_model/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
237
src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py
Normal file
237
src/lerobot/policies/molmoact2/hf_model/action_tokenizer.py
Normal file
@@ -0,0 +1,237 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
from tokenizers import ByteLevelBPETokenizer
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
|
||||
def _hf_token() -> str | None:
|
||||
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def _resolve_tokenizer_location(
|
||||
tokenizer_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> str:
|
||||
local_path = Path(str(tokenizer_path)).expanduser()
|
||||
if local_path.exists():
|
||||
return str(local_path)
|
||||
return snapshot_download(
|
||||
repo_id=str(tokenizer_path),
|
||||
repo_type="model",
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
|
||||
token=_hf_token(),
|
||||
)
|
||||
|
||||
|
||||
class UniversalActionProcessor(ProcessorMixin):
|
||||
attributes: ClassVar[list[str]] = ["tokenizer"]
|
||||
tokenizer_class: str = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerFast,
|
||||
scale: float = 10,
|
||||
vocab_size: int = 1024,
|
||||
min_token: int = 0,
|
||||
*,
|
||||
action_dim: int | None = None,
|
||||
time_horizon: int | None = None,
|
||||
):
|
||||
self.scale = scale
|
||||
self.vocab_size = vocab_size
|
||||
self.min_token = min_token
|
||||
|
||||
# Action horizon and dimension needed during decoding. These can be specified
|
||||
# in three ways (in order of priority):
|
||||
# 1. passed in as kwargs to decode()
|
||||
# 2. in the constructor
|
||||
# 3. cached from the last time decode() was called
|
||||
self.time_horizon = time_horizon
|
||||
self.action_dim = action_dim
|
||||
self.called_time_horizon = time_horizon
|
||||
self.called_action_dim = action_dim
|
||||
|
||||
super().__init__(tokenizer)
|
||||
self.bpe_tokenizer = self.tokenizer
|
||||
|
||||
def __call__(self, action_chunk: np.array) -> np.array:
|
||||
from scipy.fft import dct
|
||||
|
||||
assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
|
||||
if action_chunk.ndim == 2:
|
||||
action_chunk = action_chunk[None, ...]
|
||||
|
||||
# Cache the time horizon and action dimension for decoding
|
||||
self.called_time_horizon = action_chunk.shape[-2]
|
||||
self.called_action_dim = action_chunk.shape[-1]
|
||||
|
||||
dct_coeff = dct(action_chunk, axis=1, norm="ortho")
|
||||
dct_coeff = np.around(dct_coeff * self.scale)
|
||||
tokens = []
|
||||
for elem in dct_coeff:
|
||||
token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
|
||||
tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
|
||||
return tokens
|
||||
|
||||
def decode(
|
||||
self,
|
||||
tokens: list[list[int]],
|
||||
*,
|
||||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
) -> np.array:
|
||||
from scipy.fft import idct
|
||||
|
||||
self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
|
||||
self.action_dim = action_dim or self.action_dim or self.called_action_dim
|
||||
|
||||
# Cache the time horizon and action dimension for the next call
|
||||
self.called_time_horizon = self.time_horizon
|
||||
self.called_action_dim = self.action_dim
|
||||
|
||||
assert self.time_horizon is not None and self.action_dim is not None, (
|
||||
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
|
||||
)
|
||||
|
||||
decoded_actions = []
|
||||
for token in tokens:
|
||||
try:
|
||||
decoded_tokens = self.bpe_tokenizer.decode(token)
|
||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
|
||||
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
|
||||
assert decoded_dct_coeff.shape == (
|
||||
self.time_horizon,
|
||||
self.action_dim,
|
||||
), (
|
||||
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error decoding tokens: {e}")
|
||||
print(f"Tokens: {token}")
|
||||
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
|
||||
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
|
||||
return np.stack(decoded_actions)
|
||||
|
||||
@classmethod
|
||||
def fit(
|
||||
cls,
|
||||
action_data: list[np.array],
|
||||
scale: float = 10,
|
||||
vocab_size: int = 1024,
|
||||
*,
|
||||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
) -> "UniversalActionProcessor":
|
||||
from scipy.fft import dct
|
||||
|
||||
# Run DCT over all inputs
|
||||
dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]
|
||||
|
||||
# Quantize and find min token
|
||||
max_token = int(np.around(np.concatenate(dct_tokens) * scale).max())
|
||||
min_token = int(np.around(np.concatenate(dct_tokens) * scale).min())
|
||||
min_vocab_size = max_token - min_token
|
||||
|
||||
assert min_vocab_size <= vocab_size, (
|
||||
f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
|
||||
)
|
||||
if min_vocab_size + 100 > vocab_size:
|
||||
logging.warning(
|
||||
f"Initial alphabet size {min_vocab_size} is almost as large as the vocab"
|
||||
f"size {vocab_size}, consider increasing vocab size"
|
||||
)
|
||||
|
||||
# Make token iterator for BPE training
|
||||
def _token_iter():
|
||||
for tokens in dct_tokens:
|
||||
rounded_tokens = np.around(tokens * scale) - min_token
|
||||
rounded_tokens = rounded_tokens.astype(int)
|
||||
string = "".join(map(chr, rounded_tokens))
|
||||
yield string
|
||||
|
||||
# Train BPE tokenizer
|
||||
bpe = ByteLevelBPETokenizer()
|
||||
|
||||
# Set up the entire range of possible tokens as the initial alphabet
|
||||
alphabet = [chr(i) for i in range(max_token - min_token + 1)]
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=vocab_size,
|
||||
min_frequency=2,
|
||||
show_progress=True,
|
||||
special_tokens=[],
|
||||
initial_alphabet=alphabet,
|
||||
max_token_length=10000,
|
||||
)
|
||||
|
||||
# Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
|
||||
# because it doesn't support custom alphabets)
|
||||
bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)
|
||||
|
||||
return cls(
|
||||
PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
|
||||
scale=scale,
|
||||
vocab_size=vocab_size,
|
||||
min_token=min_token,
|
||||
time_horizon=time_horizon,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained_local(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
) -> "UniversalActionProcessor":
|
||||
location = Path(
|
||||
_resolve_tokenizer_location(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
)
|
||||
processor_config = {}
|
||||
processor_config_path = location / "processor_config.json"
|
||||
if processor_config_path.exists():
|
||||
import json
|
||||
|
||||
processor_config = json.loads(processor_config_path.read_text())
|
||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(str(location))
|
||||
return cls(
|
||||
tokenizer,
|
||||
scale=processor_config.get("scale", 10),
|
||||
vocab_size=processor_config.get("vocab_size", 1024),
|
||||
min_token=processor_config.get("min_token", 0),
|
||||
action_dim=processor_config.get("action_dim"),
|
||||
time_horizon=processor_config.get("time_horizon"),
|
||||
)
|
||||
@@ -0,0 +1,553 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""
|
||||
MolmoAct2 configuration
|
||||
"""
|
||||
|
||||
from typing import Optional, Any
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MolmoAct2VitConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`].
|
||||
It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments,
|
||||
defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer
|
||||
|
||||
>>> # Initializing a MolmoAct2VitConfig
|
||||
>>> configuration = MolmoAct2VitConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2VisionTransformer (with random weights)
|
||||
>>> model = MolmoAct2VisionTransformer(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "molmoact2"
|
||||
base_config_key = "vit_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1152,
|
||||
intermediate_size: int = 4304,
|
||||
num_hidden_layers: int = 27,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 16,
|
||||
head_dim: int = 72,
|
||||
hidden_act: str = "gelu_pytorch_tanh",
|
||||
layer_norm_eps: float = 1e-6,
|
||||
image_default_input_size: tuple[int, int] = (378, 378),
|
||||
image_patch_size: int = 14,
|
||||
image_num_pos: int = 577,
|
||||
attention_dropout: float = 0.0,
|
||||
residual_dropout: float = 0.0,
|
||||
initializer_range: float = 0.02,
|
||||
float32_attention: bool = True,
|
||||
attn_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.attn_implementation = attn_implementation
|
||||
super().__init__(attn_implementation=attn_implementation, **kwargs)
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_act = hidden_act
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.image_default_input_size = image_default_input_size
|
||||
self.image_patch_size = image_patch_size
|
||||
self.image_num_pos = image_num_pos
|
||||
self.attention_dropout = attention_dropout
|
||||
self.residual_dropout = residual_dropout
|
||||
self.initializer_range = initializer_range
|
||||
self.float32_attention = float32_attention
|
||||
|
||||
@property
|
||||
def image_num_patch(self):
|
||||
h, w = self.image_default_input_size
|
||||
return h // self.image_patch_size, w // self.image_patch_size
|
||||
|
||||
|
||||
class MolmoAct2AdapterConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of MolmoAct2Adapter. With MolmoAct2VitConfig,
|
||||
It is used to instantiate an MolmoAct2VisionBackbone according to the specified arguments,
|
||||
defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone
|
||||
|
||||
>>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
|
||||
>>> vit_config = MolmoAct2VitConfig()
|
||||
>>> adapter_config = MolmoPoolingConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2VisionBackbone (with random weights)
|
||||
>>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> vit_configuration = model.vit_config
|
||||
>>> adapter_configuration = model.adapter_config
|
||||
```"""
|
||||
|
||||
model_type = "molmoact2"
|
||||
base_config_key = "adapter_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vit_layers: tuple = (-3, -9),
|
||||
pooling_attention_mask: bool = False,
|
||||
hidden_size: int = 1152,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 16,
|
||||
head_dim: int = 72,
|
||||
float32_attention: bool = True,
|
||||
attention_dropout: float = 0.0,
|
||||
residual_dropout: float = 0.0,
|
||||
hidden_act: str = "silu",
|
||||
intermediate_size: int = 18944,
|
||||
text_hidden_size: int = 3584,
|
||||
image_feature_dropout: float = 0.0,
|
||||
initializer_range: float = 0.02,
|
||||
attn_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.attn_implementation = attn_implementation
|
||||
super().__init__(attn_implementation=attn_implementation, **kwargs)
|
||||
self.vit_layers = vit_layers
|
||||
self.pooling_attention_mask = pooling_attention_mask
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.float32_attention = float32_attention
|
||||
self.attention_dropout = attention_dropout
|
||||
self.residual_dropout = residual_dropout
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.text_hidden_size = text_hidden_size
|
||||
self.image_feature_dropout = image_feature_dropout
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
|
||||
class MolmoAct2TextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
|
||||
`MolmoAct2TextModel` according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel
|
||||
|
||||
>>> # Initializing a MolmoAct2TextConfig
|
||||
>>> configuration = MolmoAct2TextConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2TextModel (with random weights)
|
||||
>>> model = MolmoAct2TextModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "molmoact2_text"
|
||||
base_config_key = "text_config"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
base_model_tp_plan = {
|
||||
"blocks.*.self_attn.att_proj": "colwise",
|
||||
"blocks.*.self_attn.attn_out": "rowwise",
|
||||
"blocks.*.mlp.ff_proj": "colwise",
|
||||
"blocks.*.mlp.ff_out": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"wte": (["input_ids"], ["inputs_embeds"]),
|
||||
"blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"ln_f": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 3584,
|
||||
num_attention_heads: int = 28,
|
||||
num_key_value_heads: int | None = 4,
|
||||
head_dim: int = 128,
|
||||
vocab_size: int = 152064,
|
||||
additional_vocab_size: int = 128,
|
||||
qkv_bias: bool = True,
|
||||
num_hidden_layers: int = 48,
|
||||
intermediate_size: int = 18944,
|
||||
hidden_act: str = "silu",
|
||||
embedding_dropout: float = 0.0,
|
||||
attention_dropout: float = 0.0,
|
||||
residual_dropout: float = 0.0,
|
||||
max_position_embeddings: int = 4096,
|
||||
rope_theta: float = 1000000.0,
|
||||
rope_scaling: dict[str, Any] = None,
|
||||
rope_scaling_layers: list[int] | None = None,
|
||||
use_qk_norm: bool = False,
|
||||
qk_norm_type: str = "olmo",
|
||||
layer_norm_eps: int = 1e-6,
|
||||
norm_after: bool = False,
|
||||
initializer_range: float = 0.02,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
attn_implementation: str = "eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.attn_implementation = attn_implementation
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings, attn_implementation=attn_implementation, **kwargs
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.vocab_size = vocab_size
|
||||
self.additional_vocab_size = additional_vocab_size
|
||||
self.qkv_bias = qkv_bias
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.embedding_dropout = embedding_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.residual_dropout = residual_dropout
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.rope_scaling_layers = rope_scaling_layers
|
||||
self.use_qk_norm = use_qk_norm
|
||||
self.qk_norm_type = qk_norm_type
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.norm_after = norm_after
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
rope_config_validation(self)
|
||||
|
||||
|
||||
class MolmoAct2ActionExpertConfig(PretrainedConfig):
|
||||
r"""Configuration for the MolmoAct2 modern action expert."""
|
||||
|
||||
model_type = "molmoact2_action_expert"
|
||||
base_config_key = "action_expert_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_action_horizon: int = 32,
|
||||
max_action_dim: int = 32,
|
||||
hidden_size: int = 1024,
|
||||
num_layers: int = 32,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 8.0 / 3.0,
|
||||
ffn_multiple_of: int = 256,
|
||||
timestep_embed_dim: int = 256,
|
||||
dropout: float = 0.0,
|
||||
attn_dropout: float = 0.0,
|
||||
context_layer_norm: bool = True,
|
||||
qk_norm: bool = True,
|
||||
qk_norm_eps: float = 1e-6,
|
||||
rope: bool = True,
|
||||
causal_attn: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.max_action_horizon = max_action_horizon
|
||||
self.max_action_dim = max_action_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.ffn_multiple_of = ffn_multiple_of
|
||||
self.timestep_embed_dim = timestep_embed_dim
|
||||
self.dropout = dropout
|
||||
self.attn_dropout = attn_dropout
|
||||
self.context_layer_norm = context_layer_norm
|
||||
self.qk_norm = qk_norm
|
||||
self.qk_norm_eps = qk_norm_eps
|
||||
self.rope = rope
|
||||
self.causal_attn = causal_attn
|
||||
|
||||
def to_dict(self):
|
||||
output = super().to_dict()
|
||||
# These are derived from the parent MolmoAct2Config for HF exports. Keeping
|
||||
# them out of the public nested config avoids duplicated sources of truth.
|
||||
output.pop("max_action_horizon", None)
|
||||
output.pop("max_action_dim", None)
|
||||
return output
|
||||
|
||||
|
||||
class MolmoAct2Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
|
||||
It is used to instantiate an MolmoAct2 model according to the specified arguments, defining the model architecture.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig
|
||||
|
||||
>>> # Initializing a MolmoAct2VitConfig
|
||||
>>> vit_config = MolmoAct2VitConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2AdapterConfig
|
||||
>>> adapter_config = MolmoAct2AdapterConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2TextConfig
|
||||
>>> text_config = MolmoAct2TextConfig()
|
||||
|
||||
>>> # Initializing a MolmoAct2Config
|
||||
>>> configuration = MolmoAct2Config(
|
||||
>>> vit_config=vit_config,
|
||||
>>> adapter_config=adapter_config,
|
||||
>>> text_config=text_config,
|
||||
>>> image_start_token_id=151936,
|
||||
>>> image_end_token_id=151937,
|
||||
>>> image_patch_id=151938,
|
||||
>>> image_col_id=151939,
|
||||
>>> low_res_image_start_token_id=151940,
|
||||
>>> image_low_res_id=151942,
|
||||
>>> frame_start_token_id=151943,
|
||||
>>> frame_end_token_id=151944,
|
||||
>>> )
|
||||
|
||||
>>> # Initializing a model
|
||||
>>> model = MolmoAct2ForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "molmoact2"
|
||||
sub_configs = {
|
||||
"text_config": MolmoAct2TextConfig,
|
||||
"vit_config": MolmoAct2VitConfig,
|
||||
"adapter_config": MolmoAct2AdapterConfig,
|
||||
"action_expert_config": MolmoAct2ActionExpertConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vit_config: MolmoAct2VitConfig = None,
|
||||
adapter_config: MolmoAct2AdapterConfig = None,
|
||||
text_config: MolmoAct2TextConfig = None,
|
||||
action_expert_config: MolmoAct2ActionExpertConfig = None,
|
||||
image_start_token_id: int = None,
|
||||
low_res_image_start_token_id: int = None,
|
||||
image_end_token_id: int = None,
|
||||
image_low_res_id: int = None,
|
||||
image_patch_id: int = None,
|
||||
image_col_id: int = None,
|
||||
frame_start_token_id: int = None,
|
||||
frame_end_token_id: int = None,
|
||||
use_frame_special_tokens: bool = True,
|
||||
initializer_range: float = 0.02,
|
||||
add_action_expert: bool = True,
|
||||
max_action_dim: int = 32,
|
||||
max_action_horizon: int = 30,
|
||||
n_obs_steps: int = 30,
|
||||
action_mode: str = "both",
|
||||
state_format: str = "discrete",
|
||||
flow_matching_num_steps: int = 10,
|
||||
flow_matching_cutoff: float = 1.0,
|
||||
flow_matching_time_offset: float = 0.001,
|
||||
flow_matching_time_scale: float = 0.999,
|
||||
flow_matching_beta_alpha: float = 1.0,
|
||||
flow_matching_beta_beta: float = 1.5,
|
||||
mask_action_dim_padding: bool = True,
|
||||
enable_depth_reasoning: bool = False,
|
||||
depth_mode: int = 2,
|
||||
num_depth_codes: int = 100,
|
||||
action_expert_depth_gate: bool = False,
|
||||
action_expert_depth_gate_per_layer: bool = False,
|
||||
action_expert_depth_gate_init_bias: float = -4.0,
|
||||
action_output_token_id: int = None,
|
||||
action_start_token_id: int = None,
|
||||
action_end_token_id: int = None,
|
||||
action_token_start_id: int = None,
|
||||
num_action_tokens: int = 0,
|
||||
depth_output_token_id: int = None,
|
||||
depth_start_token_id: int = None,
|
||||
depth_end_token_id: int = None,
|
||||
depth_token_start_id: int = None,
|
||||
num_depth_tokens: int = 0,
|
||||
state_start_token_id: int = None,
|
||||
state_end_token_id: int = None,
|
||||
state_token_start_id: int = None,
|
||||
num_state_tokens: int = 0,
|
||||
add_setup_tokens: bool = True,
|
||||
add_control_tokens: bool = True,
|
||||
norm_stats_filename: str = "norm_stats.json",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if vit_config is None:
|
||||
self.vit_config = MolmoAct2VitConfig()
|
||||
elif isinstance(vit_config, dict):
|
||||
self.vit_config = MolmoAct2VitConfig(**vit_config)
|
||||
else:
|
||||
self.vit_config = vit_config
|
||||
if adapter_config is None:
|
||||
self.adapter_config = MolmoAct2AdapterConfig()
|
||||
elif isinstance(adapter_config, dict):
|
||||
self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
|
||||
else:
|
||||
self.adapter_config = adapter_config
|
||||
if text_config is None:
|
||||
self.text_config = MolmoAct2TextConfig()
|
||||
elif isinstance(text_config, dict):
|
||||
self.text_config = MolmoAct2TextConfig(**text_config)
|
||||
else:
|
||||
self.text_config = text_config
|
||||
self.add_action_expert = bool(add_action_expert)
|
||||
if not self.add_action_expert:
|
||||
self.action_expert_config = None
|
||||
elif action_expert_config is None:
|
||||
self.action_expert_config = MolmoAct2ActionExpertConfig(
|
||||
max_action_horizon=max_action_horizon,
|
||||
max_action_dim=max_action_dim,
|
||||
num_layers=self.text_config.num_hidden_layers,
|
||||
)
|
||||
elif isinstance(action_expert_config, dict):
|
||||
self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
|
||||
else:
|
||||
self.action_expert_config = action_expert_config
|
||||
if self.add_action_expert:
|
||||
self.action_expert_config.max_action_dim = int(max_action_dim)
|
||||
self.action_expert_config.max_action_horizon = int(max_action_horizon)
|
||||
self._validate_release_action_config(
|
||||
state_format=state_format,
|
||||
)
|
||||
self.image_start_token_id = image_start_token_id
|
||||
self.low_res_image_start_token_id = low_res_image_start_token_id
|
||||
self.image_end_token_id = image_end_token_id
|
||||
self.image_low_res_id = image_low_res_id
|
||||
self.image_high_res_id = image_patch_id
|
||||
self.image_patch_id = image_patch_id
|
||||
self.image_col_id = image_col_id
|
||||
self.frame_start_token_id = frame_start_token_id
|
||||
self.frame_end_token_id = frame_end_token_id
|
||||
self.use_frame_special_tokens = use_frame_special_tokens
|
||||
self.initializer_range = initializer_range
|
||||
self.max_action_dim = max_action_dim
|
||||
self.max_action_horizon = max_action_horizon
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.action_mode = action_mode
|
||||
self.state_format = state_format
|
||||
self.flow_matching_num_steps = flow_matching_num_steps
|
||||
self.flow_matching_cutoff = flow_matching_cutoff
|
||||
self.flow_matching_time_offset = flow_matching_time_offset
|
||||
self.flow_matching_time_scale = flow_matching_time_scale
|
||||
self.flow_matching_beta_alpha = flow_matching_beta_alpha
|
||||
self.flow_matching_beta_beta = flow_matching_beta_beta
|
||||
self.mask_action_dim_padding = mask_action_dim_padding
|
||||
self.enable_depth_reasoning = enable_depth_reasoning
|
||||
self.depth_mode = depth_mode
|
||||
self.num_depth_codes = num_depth_codes
|
||||
self.action_expert_depth_gate = action_expert_depth_gate
|
||||
self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
|
||||
self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
|
||||
self.action_output_token_id = action_output_token_id
|
||||
self.action_start_token_id = action_start_token_id
|
||||
self.action_end_token_id = action_end_token_id
|
||||
self.action_token_start_id = action_token_start_id
|
||||
self.num_action_tokens = num_action_tokens
|
||||
self.depth_output_token_id = depth_output_token_id
|
||||
self.depth_start_token_id = depth_start_token_id
|
||||
self.depth_end_token_id = depth_end_token_id
|
||||
self.depth_token_start_id = depth_token_start_id
|
||||
self.num_depth_tokens = num_depth_tokens
|
||||
self.state_start_token_id = state_start_token_id
|
||||
self.state_end_token_id = state_end_token_id
|
||||
self.state_token_start_id = state_token_start_id
|
||||
self.num_state_tokens = num_state_tokens
|
||||
self.add_setup_tokens = add_setup_tokens
|
||||
self.add_control_tokens = add_control_tokens
|
||||
self.norm_stats_filename = norm_stats_filename
|
||||
|
||||
@staticmethod
|
||||
def _validate_release_action_config(
|
||||
*,
|
||||
state_format: str,
|
||||
) -> None:
|
||||
if state_format != "discrete":
|
||||
raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")
|
||||
|
||||
@property
|
||||
def image_num_patch(self):
|
||||
assert self.vit_config is not None
|
||||
return self.vit_config.image_num_patch
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
return self.text_config.num_attention_heads
|
||||
|
||||
@property
|
||||
def num_key_value_heads(self):
|
||||
return self.text_config.num_key_value_heads
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.text_config.head_dim
|
||||
|
||||
@property
|
||||
def num_hidden_layers(self):
|
||||
return self.text_config.num_hidden_layers
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
return self.text_config.hidden_size
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.text_config.vocab_size
|
||||
|
||||
@property
|
||||
def max_position_embeddings(self):
|
||||
return self.text_config.max_position_embeddings
|
||||
|
||||
|
||||
MolmoAct2VitConfig.register_for_auto_class()
|
||||
MolmoAct2AdapterConfig.register_for_auto_class()
|
||||
MolmoAct2TextConfig.register_for_auto_class()
|
||||
MolmoAct2ActionExpertConfig.register_for_auto_class()
|
||||
MolmoAct2Config.register_for_auto_class()
|
||||
@@ -0,0 +1,564 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Image processor class for MolmoAct2"""
|
||||
|
||||
from typing import Optional, Union
|
||||
import numpy as np
|
||||
import einops
|
||||
import torch
|
||||
import torchvision.transforms
|
||||
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
make_flat_list_of_images,
|
||||
valid_images,
|
||||
to_numpy_array,
|
||||
)
|
||||
from transformers.image_transforms import convert_to_rgb
|
||||
from transformers.processing_utils import ImagesKwargs
|
||||
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
|
||||
from transformers.utils import logging
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.utils import TensorType, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def normalize_image(
|
||||
image: np.ndarray,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
) -> np.ndarray:
|
||||
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
|
||||
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
|
||||
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
|
||||
image /= np.array(image_std, dtype=np.float32)[None, None, :]
|
||||
return image
|
||||
|
||||
|
||||
def resize_image(
|
||||
image: np.ndarray,
|
||||
desired_output_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
) -> np.ndarray:
|
||||
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
||||
dtype = image.dtype
|
||||
if torch.is_floating_point(image):
|
||||
in_min = 0.0
|
||||
in_max = 1.0
|
||||
resized = torchvision.transforms.Resize(
|
||||
desired_output_size,
|
||||
resample,
|
||||
antialias=False,
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
|
||||
else:
|
||||
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
|
||||
image.dtype
|
||||
)
|
||||
in_min = 0.0
|
||||
in_max = 255.0
|
||||
resized = torchvision.transforms.Resize(
|
||||
desired_output_size,
|
||||
resample,
|
||||
antialias=False,
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0, 255).to(dtype)
|
||||
|
||||
resized = resized.to(torch.float32)
|
||||
resized = (resized - in_min) / (in_max - in_min)
|
||||
|
||||
resized = torch.permute(resized, [1, 2, 0]).numpy()
|
||||
|
||||
return resized
|
||||
|
||||
|
||||
def select_tiling(h, w, patch_size, max_num_crops):
|
||||
"""Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
|
||||
original_size = np.stack([h, w]) # [1, 2]
|
||||
original_res = h * w
|
||||
tilings = []
|
||||
for i in range(1, max_num_crops + 1):
|
||||
for j in range(1, max_num_crops + 1):
|
||||
if i * j <= max_num_crops:
|
||||
tilings.append((i, j))
|
||||
# sort so argmin and argmax favour smaller tilings in the event of a tie
|
||||
tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
|
||||
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
|
||||
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
|
||||
|
||||
# How much we would need to scale the image to fit exactly in each tiling
|
||||
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
|
||||
|
||||
# The original size can be zero in rare cases if the image is smaller than the margin
|
||||
# In those cases letting the scale become infinite means the tiling is based on the
|
||||
# other side, or falls back to the smallest tiling
|
||||
with np.errstate(divide="ignore"):
|
||||
required_scale_d = (candidate_resolutions.astype(np.float32) / original_size,)
|
||||
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
|
||||
if np.all(required_scale < 1):
|
||||
# We are forced to downscale, so try to minimize the amount of downscaling
|
||||
ix = np.argmax(required_scale)
|
||||
else:
|
||||
# Pick the resolution that required the least upscaling so that it most closely fits the image
|
||||
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
|
||||
ix = np.argmin(required_scale)
|
||||
return candidate_tilings[ix]
|
||||
|
||||
|
||||
def build_resized_image(
|
||||
image: np.ndarray,
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
resized = resize_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
)
|
||||
resized = normalize_image(resized, image_mean, image_std)
|
||||
if len(resized.shape) == 3:
|
||||
resized = np.expand_dims(resized, 0)
|
||||
crop_patch_w = base_image_input_size[1] // image_patch_size
|
||||
crop_patch_h = base_image_input_size[0] // image_patch_size
|
||||
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
|
||||
return resized, resize_idx
|
||||
|
||||
|
||||
def build_overlapping_crops(
|
||||
image: np.ndarray,
|
||||
max_crops: int,
|
||||
overlap_margins: list[int],
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Decompose an image into a set of overlapping crops
|
||||
|
||||
:return crop_arr: [n_crops, h, w, 3] The crops
|
||||
:return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
|
||||
the crops were extracted from, what patch in `crop_arr` it corresponds to
|
||||
"""
|
||||
original_image_h, original_image_w = image.shape[:2]
|
||||
crop_size = base_image_input_size[0]
|
||||
assert base_image_input_size[0] == base_image_input_size[1]
|
||||
|
||||
left_margin, right_margin = overlap_margins
|
||||
total_margin_pixels = image_patch_size * (right_margin + left_margin) # pixels removed per dim
|
||||
crop_patches = base_image_input_size[0] // image_patch_size # patches per crop dim
|
||||
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
|
||||
crop_window_size = crop_window_patches * image_patch_size
|
||||
crop_patch_w = base_image_input_size[1] // image_patch_size
|
||||
crop_patch_h = base_image_input_size[0] // image_patch_size
|
||||
original_image_h, original_image_w = image.shape[:2]
|
||||
crop_size = base_image_input_size[0]
|
||||
|
||||
# Decide how to tile the image, to account for the overlap margins we compute the tiling
|
||||
# as if we had an image without the margins and were using a crop size without the margins
|
||||
tiling = select_tiling(
|
||||
original_image_h - total_margin_pixels,
|
||||
original_image_w - total_margin_pixels,
|
||||
crop_window_size,
|
||||
max_crops,
|
||||
)
|
||||
|
||||
src = resize_image(
|
||||
image,
|
||||
[
|
||||
tiling[0] * crop_window_size + total_margin_pixels,
|
||||
tiling[1] * crop_window_size + total_margin_pixels,
|
||||
],
|
||||
resample,
|
||||
)
|
||||
src = normalize_image(src, image_mean, image_std)
|
||||
|
||||
# Now we have to split the image into crops, and track what patches came from
|
||||
# where in `patch_idx_arr`
|
||||
n_crops = tiling[0] * tiling[1]
|
||||
crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
|
||||
patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
|
||||
on_crop = 0
|
||||
for i in range(tiling[0]):
|
||||
# Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size`
|
||||
# which results in overlapping crop windows
|
||||
y0 = i * crop_window_size
|
||||
for j in range(tiling[1]):
|
||||
x0 = j * crop_window_size
|
||||
crop_arr[on_crop] = src[y0 : y0 + crop_size, x0 : x0 + crop_size]
|
||||
patch_idx = np.arange(crop_patch_w * crop_patch_h).reshape(crop_patch_h, crop_patch_w)
|
||||
patch_idx += on_crop * crop_patch_h * crop_patch_w
|
||||
|
||||
# Mask out idx that are in the overlap region
|
||||
if i != 0:
|
||||
patch_idx[:left_margin, :] = -1
|
||||
if j != 0:
|
||||
patch_idx[:, :left_margin] = -1
|
||||
if i != tiling[0] - 1:
|
||||
patch_idx[-right_margin:, :] = -1
|
||||
if j != tiling[1] - 1:
|
||||
patch_idx[:, -right_margin:] = -1
|
||||
patch_idx_arr[on_crop] = patch_idx
|
||||
on_crop += 1
|
||||
|
||||
# `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
|
||||
# so it is ordered left-to-right order
|
||||
patch_idx_arr = np.reshape(patch_idx_arr, [tiling[0], tiling[1], crop_patch_h, crop_patch_w])
|
||||
patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
|
||||
patch_idx_arr = np.reshape(patch_idx_arr, [-1])
|
||||
|
||||
# Now get the parts not in the overlap region, so it should map each patch in `src`
|
||||
# to the correct patch it should come from in `crop_arr`
|
||||
patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
|
||||
src.shape[0] // image_patch_size,
|
||||
src.shape[1] // image_patch_size,
|
||||
)
|
||||
return crop_arr, patch_idx_arr
|
||||
|
||||
|
||||
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
|
||||
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
|
||||
if len(array.shape) == 3:
|
||||
n_crops, h, w = array.shape
|
||||
h_patches = h // patch_size
|
||||
w_patches = w // patch_size
|
||||
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
|
||||
array = np.transpose(array, [0, 1, 3, 2, 4])
|
||||
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
|
||||
return array
|
||||
else:
|
||||
n_crops, h, w, c = array.shape
|
||||
h_patches = h // patch_size
|
||||
w_patches = w // patch_size
|
||||
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
|
||||
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
|
||||
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
|
||||
return array
|
||||
|
||||
|
||||
def arange_for_pooling(
|
||||
idx_arr: np.ndarray,
|
||||
pool_h: int,
|
||||
pool_w: int,
|
||||
) -> np.ndarray:
|
||||
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
|
||||
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
|
||||
idx_arr = np.pad(
|
||||
idx_arr,
|
||||
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
|
||||
mode="constant",
|
||||
constant_values=-1,
|
||||
)
|
||||
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
|
||||
|
||||
|
||||
def image_to_patches_and_grids(
|
||||
image: np.ndarray,
|
||||
max_crops: int,
|
||||
overlap_margins: list[int],
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
image_pooling_w: int,
|
||||
image_pooling_h: int,
|
||||
crop_mode: str = "overlap-and-resize-c2",
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
:return image_grids, the shape of each (low-res, high-res) image after pooling
|
||||
:return crops, the image crops to processes with the ViT
|
||||
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
|
||||
patches in `crops` to pool for that token, masked with -1
|
||||
"""
|
||||
if isinstance(base_image_input_size, int):
|
||||
base_image_input_size = (base_image_input_size, base_image_input_size)
|
||||
|
||||
base_image_input_d = image_patch_size
|
||||
pooling_w = image_pooling_w
|
||||
pooling_h = image_pooling_h
|
||||
crop_patch_w = base_image_input_size[1] // base_image_input_d
|
||||
crop_patch_h = base_image_input_size[0] // base_image_input_d
|
||||
|
||||
if crop_mode == "resize":
|
||||
resized, resize_idx = build_resized_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
image_patch_size,
|
||||
)
|
||||
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
||||
resized_h, resized_w = resize_idx.shape[:2]
|
||||
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
|
||||
image_grid = [np.array([resized_h, resized_w, 0, 0])]
|
||||
return (
|
||||
np.stack(image_grid, 0),
|
||||
batch_pixels_to_patches(resized, image_patch_size),
|
||||
resize_idx,
|
||||
)
|
||||
|
||||
if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
|
||||
raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")
|
||||
|
||||
crop_arr, patch_idx_arr = build_overlapping_crops(
|
||||
image,
|
||||
max_crops,
|
||||
overlap_margins,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
image_patch_size,
|
||||
)
|
||||
pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
|
||||
h, w = pooling_idx.shape[:2]
|
||||
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
|
||||
|
||||
# Finally do the same for the global image
|
||||
resized, resize_idx = build_resized_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
image_patch_size,
|
||||
)
|
||||
crop_arr = np.concatenate([resized, crop_arr], 0)
|
||||
|
||||
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
||||
resized_h, resized_w = resize_idx.shape[:2]
|
||||
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
|
||||
|
||||
# Global image goes first, so the order of patches in previous crops gets increased
|
||||
pooling_idx = np.where(pooling_idx >= 0, pooling_idx + crop_patch_h * crop_patch_w, -1)
|
||||
pooling_idx = np.concatenate([resize_idx, pooling_idx])
|
||||
image_grid = [np.array([resized_h, resized_w, h, w])]
|
||||
|
||||
return (np.stack(image_grid, 0), batch_pixels_to_patches(crop_arr, image_patch_size), pooling_idx)
|
||||
|
||||
|
||||
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
|
||||
max_crops: int | None
|
||||
overlap_margins: list[int] | None
|
||||
crop_mode: str | None
|
||||
patch_size: int | None
|
||||
pooling_size: list[int] | None
|
||||
|
||||
|
||||
class MolmoAct2ImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a MolmoAct2 image processor that preprocesses images for the model.
|
||||
|
||||
Args:
|
||||
size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
|
||||
Size of the image after resizing.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use when resizing the image.
|
||||
image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
||||
image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
max_crops (`int`, *optional*, defaults to `8`):
|
||||
Maximum number of crops to use per image.
|
||||
overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
|
||||
Overlap margins to use.
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The spatial patch size of the vision encoder.
|
||||
pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
|
||||
The pooling size of the vision adapter.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: dict[str, int] | None = None,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
image_mean: float | list[float] | None = None,
|
||||
image_std: float | list[float] | None = None,
|
||||
do_convert_rgb: bool = True,
|
||||
max_crops: int = 8,
|
||||
overlap_margins: list[int] = [4, 4],
|
||||
crop_mode: str = "overlap-and-resize-c2",
|
||||
patch_size: int = 14,
|
||||
pooling_size: list[int] = [2, 2],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"height": 378, "width": 378}
|
||||
size = get_size_dict(size, default_to_square=True)
|
||||
self.size = size
|
||||
|
||||
self.resample = resample
|
||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
self.max_crops = max_crops
|
||||
self.overlap_margins = overlap_margins
|
||||
self.crop_mode = crop_mode
|
||||
self.patch_size = patch_size
|
||||
self.pooling_size = pooling_size
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
size: dict[str, int] | None = None,
|
||||
resample: PILImageResampling | None = None,
|
||||
image_mean: float | list[float] | None = None,
|
||||
image_std: float | list[float] | None = None,
|
||||
do_convert_rgb: bool | None = None,
|
||||
max_crops: int | None = None,
|
||||
overlap_margins: list[int] | None = None,
|
||||
crop_mode: str | None = None,
|
||||
patch_size: int | None = None,
|
||||
pooling_size: list[int] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess.
|
||||
size (`dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
max_crops (`int`, *optional*, defaults to `self.max_crops`):
|
||||
Maximum number of crops to use per image.
|
||||
overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
|
||||
Overlap margins to use.
|
||||
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
||||
The spatial patch size of the vision encoder.
|
||||
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
|
||||
The pooling size of the vision adapter.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
|
||||
Returns:
|
||||
A `BatchFeature` containing the following keys:
|
||||
- `pixel_values`: The preprocessed images.
|
||||
- `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
|
||||
- `image_grids`: The image grids.
|
||||
- `image_num_crops`: The number of crops for each image.
|
||||
"""
|
||||
if size is not None:
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError("size must contain 'height' and 'width' keys.")
|
||||
else:
|
||||
size = {**self.size}
|
||||
|
||||
base_image_input_size = [size["height"], size["width"]]
|
||||
|
||||
resample = resample or self.resample
|
||||
image_mean = image_mean or self.image_mean
|
||||
image_std = image_std or self.image_std
|
||||
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
|
||||
|
||||
max_crops = max_crops or self.max_crops
|
||||
overlap_margins = overlap_margins or self.overlap_margins
|
||||
crop_mode = crop_mode or self.crop_mode
|
||||
patch_size = patch_size or self.patch_size
|
||||
pooling_size = pooling_size or self.pooling_size
|
||||
|
||||
image_pooling_h, image_pooling_w = pooling_size
|
||||
|
||||
if images is not None:
|
||||
images = self.fetch_images(images)
|
||||
images = make_flat_list_of_images(images)
|
||||
|
||||
if images is not None and not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
data = {}
|
||||
if images is not None:
|
||||
batch_grids = []
|
||||
batch_crops = []
|
||||
batch_pooled_patches_idx = []
|
||||
batch_num_crops = []
|
||||
|
||||
for image in images:
|
||||
image_grid, crops, pooled_idx = image_to_patches_and_grids(
|
||||
image,
|
||||
max_crops,
|
||||
overlap_margins,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
patch_size,
|
||||
image_pooling_w,
|
||||
image_pooling_h,
|
||||
crop_mode,
|
||||
)
|
||||
batch_grids.append(image_grid)
|
||||
batch_crops.append(crops)
|
||||
batch_pooled_patches_idx.append(pooled_idx)
|
||||
batch_num_crops.append(crops.shape[0])
|
||||
|
||||
pixel_values = np.concatenate(batch_crops, 0)
|
||||
image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
||||
image_grids = np.concatenate(batch_grids, 0)
|
||||
image_num_crops = np.array(batch_num_crops)
|
||||
|
||||
data.update(
|
||||
pixel_values=pixel_values,
|
||||
image_token_pooling=image_token_pooling,
|
||||
image_grids=image_grids,
|
||||
image_num_crops=image_num_crops,
|
||||
)
|
||||
|
||||
return BatchFeature(data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
MolmoAct2ImageProcessor.register_for_auto_class()
|
||||
748
src/lerobot/policies/molmoact2/hf_model/inference.py
Normal file
748
src/lerobot/policies/molmoact2/hf_model/inference.py
Normal file
@@ -0,0 +1,748 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Inference utilities for MolmoAct2"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ActionFlowInputs:
|
||||
trajectory: torch.Tensor
|
||||
context: Any
|
||||
modulations: Sequence[Any]
|
||||
action_dim_is_pad: torch.Tensor | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ActionFlowCudaGraph:
|
||||
key: tuple[Any, ...]
|
||||
graph: torch.cuda.CUDAGraph
|
||||
static_inputs: _ActionFlowInputs
|
||||
output: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DepthDecodeCudaGraphLayerStage:
|
||||
residual: torch.Tensor
|
||||
query: torch.Tensor
|
||||
key: torch.Tensor
|
||||
value: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DepthDecodeCudaGraphPostStage:
|
||||
graph: torch.cuda.CUDAGraph
|
||||
attn_context: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DepthDecodeCudaGraph:
|
||||
cache_key: tuple[Any, ...]
|
||||
pre_graph: torch.cuda.CUDAGraph
|
||||
token_ids: torch.Tensor
|
||||
cos: torch.Tensor
|
||||
sin: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
stages: Sequence[_DepthDecodeCudaGraphLayerStage]
|
||||
post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
|
||||
output: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DepthDecodeCudaGraphSpec:
|
||||
eligible: bool
|
||||
cache_key_prefix: tuple[Any, ...]
|
||||
num_hidden_layers: int
|
||||
head_dim: int
|
||||
num_attention_heads: int
|
||||
|
||||
|
||||
def _cache_seq_len_int(past_key_values: Cache | None) -> int:
|
||||
if past_key_values is None:
|
||||
return 0
|
||||
seq_len = past_key_values.get_seq_length()
|
||||
if torch.is_tensor(seq_len):
|
||||
return int(seq_len.item())
|
||||
return int(seq_len)
|
||||
|
||||
|
||||
def _cache_max_len_int(past_key_values: Cache | None) -> int:
|
||||
if past_key_values is None:
|
||||
return -1
|
||||
max_len = past_key_values.get_max_cache_shape()
|
||||
if torch.is_tensor(max_len):
|
||||
return int(max_len.item())
|
||||
return int(max_len)
|
||||
|
||||
|
||||
def _iter_cache_key_values(
|
||||
past_key_values: Cache,
|
||||
) -> Iterable[tuple[torch.Tensor | None, torch.Tensor | None]]:
|
||||
layers = getattr(past_key_values, "layers", None)
|
||||
if layers is not None:
|
||||
for layer in layers:
|
||||
yield getattr(layer, "keys", None), getattr(layer, "values", None)
|
||||
return
|
||||
for layer in past_key_values:
|
||||
yield layer[0], layer[1]
|
||||
|
||||
|
||||
class _DepthDecodeStaticLayerCache:
|
||||
is_compileable = False
|
||||
is_sliding = False
|
||||
|
||||
def __init__(self, max_cache_len: int) -> None:
|
||||
self.max_cache_len = int(max_cache_len)
|
||||
self.cumulative_length = 0
|
||||
self.keys: torch.Tensor | None = None
|
||||
self.values: torch.Tensor | None = None
|
||||
|
||||
def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
|
||||
bsz, n_heads = key_states.shape[:2]
|
||||
self.keys = torch.empty(
|
||||
(bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
|
||||
dtype=key_states.dtype,
|
||||
device=key_states.device,
|
||||
)
|
||||
self.values = torch.empty(
|
||||
(bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
|
||||
dtype=value_states.dtype,
|
||||
device=value_states.device,
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.keys is None:
|
||||
self._allocate(key_states, value_states)
|
||||
start = self.cumulative_length
|
||||
end = start + key_states.shape[-2]
|
||||
if end > self.max_cache_len:
|
||||
raise RuntimeError(f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}.")
|
||||
self.keys[:, :, start:end, :].copy_(key_states)
|
||||
self.values[:, :, start:end, :].copy_(value_states)
|
||||
self.cumulative_length = end
|
||||
return self.keys[:, :, :end, :], self.values[:, :, :end, :]
|
||||
|
||||
def get_seq_length(self) -> int:
|
||||
return self.cumulative_length
|
||||
|
||||
def get_max_cache_shape(self) -> int:
|
||||
return -1
|
||||
|
||||
def reset(self) -> None:
|
||||
self.cumulative_length = 0
|
||||
|
||||
|
||||
class _DepthDecodeStaticCache(Cache):
|
||||
def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
|
||||
text_config = config.get_text_config(decoder=True)
|
||||
super().__init__(
|
||||
layers=[
|
||||
_DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
|
||||
for _ in range(text_config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def get_seq_length(self, layer_idx: int = 0) -> int:
|
||||
return self.layers[layer_idx].get_seq_length()
|
||||
|
||||
def get_max_cache_shape(self, layer_idx: int = 0) -> int:
|
||||
return self.layers[layer_idx].get_max_cache_shape()
|
||||
|
||||
def reset(self) -> None:
|
||||
for layer in self.layers:
|
||||
layer.reset()
|
||||
|
||||
|
||||
class ActionCudaGraphManager:
|
||||
def __init__(self, model: Any) -> None:
|
||||
self.model = model
|
||||
self.enabled = True
|
||||
self.action_flow_graph: _ActionFlowCudaGraph | None = None
|
||||
|
||||
def set_enabled(self, enabled: bool) -> None:
|
||||
self.enabled = bool(enabled)
|
||||
|
||||
def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
|
||||
action_model = self.model
|
||||
if not self.enabled:
|
||||
return False
|
||||
if action_model.training or action_model._require_action_expert().training:
|
||||
return False
|
||||
if inputs.trajectory.device.type != "cuda":
|
||||
return False
|
||||
|
||||
def all_on_cuda():
|
||||
yield inputs.trajectory
|
||||
for k, v in inputs.context.kv_contexts:
|
||||
yield k
|
||||
yield v
|
||||
for t in (
|
||||
inputs.context.cross_mask,
|
||||
inputs.context.self_mask,
|
||||
inputs.context.valid_action,
|
||||
inputs.action_dim_is_pad,
|
||||
):
|
||||
if t is not None:
|
||||
yield t
|
||||
if inputs.context.rope_cache is not None:
|
||||
yield from inputs.context.rope_cache
|
||||
for step in inputs.modulations:
|
||||
yield step.conditioning
|
||||
for block_modulation in step.block_modulations:
|
||||
yield from block_modulation
|
||||
yield from step.final_modulation
|
||||
|
||||
return all(t.device.type == "cuda" for t in all_on_cuda())
|
||||
|
||||
def run_action_flow(
|
||||
self,
|
||||
inputs: _ActionFlowInputs,
|
||||
steps: int,
|
||||
run_loop,
|
||||
) -> torch.Tensor:
|
||||
key = _cuda_graph_key(inputs, steps)
|
||||
cache = self.action_flow_graph
|
||||
if cache is None or cache.key != key:
|
||||
static_inputs = _clone_static_inputs(inputs)
|
||||
graph, output = _capture_cuda_graph(
|
||||
lambda: run_loop(static_inputs, steps),
|
||||
inputs.trajectory.device,
|
||||
after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
|
||||
)
|
||||
cache = _ActionFlowCudaGraph(
|
||||
key=key,
|
||||
graph=graph,
|
||||
static_inputs=static_inputs,
|
||||
output=output,
|
||||
)
|
||||
self.action_flow_graph = cache
|
||||
else:
|
||||
_copy_inputs_(cache.static_inputs, inputs)
|
||||
|
||||
cache.graph.replay()
|
||||
return cache.output.clone()
|
||||
|
||||
|
||||
class DepthDecodeCudaGraphManager:
|
||||
def __init__(self, model: Any) -> None:
|
||||
self.model = model
|
||||
self.backbone = model.model
|
||||
self.enabled = True
|
||||
self.graph: _DepthDecodeCudaGraph | None = None
|
||||
self.graph_spec: _DepthDecodeCudaGraphSpec | None = None
|
||||
|
||||
def set_enabled(self, enabled: bool) -> None:
|
||||
self.enabled = bool(enabled)
|
||||
|
||||
def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
|
||||
return _DepthDecodeStaticCache(
|
||||
config=self.model.config.text_config,
|
||||
max_cache_len=max_cache_len,
|
||||
)
|
||||
|
||||
def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
|
||||
static = self.graph_spec
|
||||
if static is None:
|
||||
cfg = self.backbone.transformer.config
|
||||
rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
|
||||
static = _DepthDecodeCudaGraphSpec(
|
||||
eligible=(
|
||||
not cfg.norm_after
|
||||
and cfg.rope_scaling_layers is None
|
||||
and getattr(rotary_emb, "rope_type", None) == "default"
|
||||
and cfg._attn_implementation == "sdpa"
|
||||
),
|
||||
cache_key_prefix=(
|
||||
cfg.hidden_size,
|
||||
cfg.num_attention_heads,
|
||||
cfg.num_key_value_heads,
|
||||
cfg.head_dim,
|
||||
cfg.num_hidden_layers,
|
||||
cfg.use_qk_norm,
|
||||
cfg.qk_norm_type,
|
||||
cfg._attn_implementation,
|
||||
),
|
||||
num_hidden_layers=cfg.num_hidden_layers,
|
||||
head_dim=cfg.head_dim,
|
||||
num_attention_heads=cfg.num_attention_heads,
|
||||
)
|
||||
self.graph_spec = static
|
||||
return static
|
||||
|
||||
def can_use(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
*,
|
||||
past_key_values: Cache,
|
||||
attention_bias: torch.Tensor,
|
||||
) -> bool:
|
||||
if not self.enabled or self.model.training or self.backbone.transformer.training:
|
||||
return False
|
||||
if next_input_ids.device.type != "cuda":
|
||||
return False
|
||||
if next_input_ids.ndim != 2 or next_input_ids.shape[0] != 1 or next_input_ids.shape[1] != 1:
|
||||
return False
|
||||
if not isinstance(past_key_values, _DepthDecodeStaticCache):
|
||||
return False
|
||||
if not torch.is_tensor(attention_bias) or attention_bias.device != next_input_ids.device:
|
||||
return False
|
||||
return self._depth_decode_spec().eligible
|
||||
|
||||
def _depth_decode_key(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
attention_bias: torch.Tensor,
|
||||
) -> tuple[Any, ...]:
|
||||
device = next_input_ids.device
|
||||
return (
|
||||
self._depth_decode_spec().cache_key_prefix,
|
||||
device.type,
|
||||
device.index,
|
||||
self.model.lm_head.weight.dtype,
|
||||
attention_bias.shape[-1],
|
||||
)
|
||||
|
||||
def _select_depth_decode_rope(self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int) -> None:
|
||||
emb = self.backbone.transformer.rotary_emb
|
||||
cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
|
||||
sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
|
||||
|
||||
def _depth_decode_pre_layer(
|
||||
self,
|
||||
layer_idx: int,
|
||||
hidden_states: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
block = self.backbone.transformer.blocks[layer_idx]
|
||||
attention = block.self_attn
|
||||
residual = hidden_states
|
||||
hidden_states = block.attn_norm(hidden_states)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, attention.head_dim)
|
||||
qkv = attention.att_proj(hidden_states)
|
||||
query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
|
||||
value_states = value_states.view(hidden_shape)
|
||||
|
||||
apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
|
||||
norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
|
||||
|
||||
if apply_qk_norm and not norm_after_view:
|
||||
query_states = attention.q_norm(query_states)
|
||||
key_states = attention.k_norm(key_states)
|
||||
|
||||
query_states = query_states.view(hidden_shape)
|
||||
key_states = key_states.view(hidden_shape)
|
||||
|
||||
if norm_after_view:
|
||||
query_states = attention.q_norm(query_states)
|
||||
key_states = attention.k_norm(key_states)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
return residual, query_states, key_states, value_states
|
||||
|
||||
def _depth_decode_pre0(
|
||||
self,
|
||||
token_ids: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
inputs_embeds = self.model._embed_base_tokens(token_ids)
|
||||
return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
|
||||
|
||||
def _depth_decode_post_layer(
|
||||
self,
|
||||
layer_idx: int,
|
||||
residual: torch.Tensor,
|
||||
attn_context: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
block = self.backbone.transformer.blocks[layer_idx]
|
||||
attention = block.self_attn
|
||||
input_shape = residual.shape[:-1]
|
||||
attn_output = attn_context.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = attention.attn_out(attn_output)
|
||||
hidden_states = residual + block.dropout(attn_output)
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = block.ff_norm(hidden_states)
|
||||
hidden_states = block.mlp(hidden_states)
|
||||
hidden_states = residual + block.dropout(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _depth_decode_post_and_pre_next(
|
||||
self,
|
||||
layer_idx: int,
|
||||
residual: torch.Tensor,
|
||||
attn_context: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
||||
return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
|
||||
|
||||
def _depth_decode_last_post(
|
||||
self,
|
||||
layer_idx: int,
|
||||
residual: torch.Tensor,
|
||||
attn_context: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
||||
return self.backbone.transformer.ln_f(hidden_states)
|
||||
|
||||
def _build_depth_decode_graph(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
*,
|
||||
past_length: int,
|
||||
attention_bias: torch.Tensor,
|
||||
) -> _DepthDecodeCudaGraph:
|
||||
text_config = self.backbone.transformer.config
|
||||
device = next_input_ids.device
|
||||
dtype = self.model.lm_head.weight.dtype
|
||||
static = self._depth_decode_spec()
|
||||
num_layers = static.num_hidden_layers
|
||||
head_dim = static.head_dim
|
||||
max_cache_len = int(attention_bias.shape[-1])
|
||||
max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
|
||||
self.backbone.transformer.prepare_rope_cache(device=device, max_seq_len=max_rope_len)
|
||||
|
||||
token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
|
||||
cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
|
||||
sin = torch.empty_like(cos)
|
||||
positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
|
||||
context_shape = (1, 1, static.num_attention_heads, head_dim)
|
||||
|
||||
token_ids.copy_(next_input_ids)
|
||||
self._select_depth_decode_rope(cos, sin, past_length=past_length)
|
||||
|
||||
pre_graph, pre_output = _capture_cuda_graph(
|
||||
lambda: self._depth_decode_pre0(token_ids, cos, sin),
|
||||
device,
|
||||
)
|
||||
stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
|
||||
post_graphs = []
|
||||
for layer_idx in range(num_layers - 1):
|
||||
stage = stages[-1]
|
||||
attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
||||
graph, output = _capture_cuda_graph(
|
||||
lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
|
||||
self._depth_decode_post_and_pre_next(
|
||||
layer_idx,
|
||||
stage.residual,
|
||||
attn_context,
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
),
|
||||
device,
|
||||
)
|
||||
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context))
|
||||
stages.append(_DepthDecodeCudaGraphLayerStage(*output))
|
||||
|
||||
last_stage = stages[-1]
|
||||
last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
||||
last_graph, last_output = _capture_cuda_graph(
|
||||
lambda: self._depth_decode_last_post(
|
||||
num_layers - 1,
|
||||
last_stage.residual,
|
||||
last_attn_context,
|
||||
),
|
||||
device,
|
||||
)
|
||||
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=last_graph, attn_context=last_attn_context))
|
||||
return _DepthDecodeCudaGraph(
|
||||
cache_key=self._depth_decode_key(next_input_ids, attention_bias),
|
||||
pre_graph=pre_graph,
|
||||
token_ids=token_ids,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
positions=positions,
|
||||
stages=tuple(stages),
|
||||
post_graphs=tuple(post_graphs),
|
||||
output=last_output,
|
||||
)
|
||||
|
||||
def _get_depth_decode_graph(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
*,
|
||||
past_length: int,
|
||||
attention_bias: torch.Tensor,
|
||||
) -> _DepthDecodeCudaGraph:
|
||||
key = self._depth_decode_key(next_input_ids, attention_bias)
|
||||
decode_graph = self.graph
|
||||
if decode_graph is None or decode_graph.cache_key != key:
|
||||
decode_graph = self._build_depth_decode_graph(
|
||||
next_input_ids,
|
||||
past_length=past_length,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
self.graph = decode_graph
|
||||
else:
|
||||
decode_graph.token_ids.copy_(next_input_ids)
|
||||
self._select_depth_decode_rope(decode_graph.cos, decode_graph.sin, past_length=past_length)
|
||||
return decode_graph
|
||||
|
||||
def _run_depth_decode_attention_core(
|
||||
self,
|
||||
layer_idx: int,
|
||||
stage: _DepthDecodeCudaGraphLayerStage,
|
||||
*,
|
||||
past_key_values: Cache,
|
||||
attention_bias: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
attention = self.backbone.transformer.blocks[layer_idx].self_attn
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_values.update(
|
||||
stage.key,
|
||||
stage.value,
|
||||
layer_idx,
|
||||
cache_kwargs,
|
||||
)
|
||||
key_states = _repeat_kv(key_states, attention.num_key_value_groups)
|
||||
value_states = _repeat_kv(value_states, attention.num_key_value_groups)
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
stage.query,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_bias,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
)
|
||||
return attn_output.transpose(1, 2)
|
||||
|
||||
def run(
|
||||
self,
|
||||
next_input_ids: torch.Tensor,
|
||||
*,
|
||||
past_key_values: Cache,
|
||||
attention_bias: torch.Tensor,
|
||||
past_length: int,
|
||||
) -> tuple[torch.Tensor, Cache]:
|
||||
end = past_length + 1
|
||||
decode_graph = self._get_depth_decode_graph(
|
||||
next_input_ids,
|
||||
past_length=past_length,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
cache_position = decode_graph.positions[past_length:end]
|
||||
attention_bias_q = attention_bias[:, :, past_length:end, :end]
|
||||
|
||||
decode_graph.pre_graph.replay()
|
||||
|
||||
for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
|
||||
attn_context = self._run_depth_decode_attention_core(
|
||||
layer_idx,
|
||||
decode_graph.stages[layer_idx],
|
||||
past_key_values=past_key_values,
|
||||
attention_bias=attention_bias_q,
|
||||
cache_position=cache_position,
|
||||
cos=decode_graph.cos,
|
||||
sin=decode_graph.sin,
|
||||
)
|
||||
post_graph.attn_context.copy_(attn_context)
|
||||
post_graph.graph.replay()
|
||||
|
||||
return decode_graph.output, past_key_values
|
||||
|
||||
|
||||
def _cuda_graph_tensor_signature(
|
||||
tensor: torch.Tensor | None,
|
||||
) -> tuple[Any, ...] | None:
|
||||
if tensor is None:
|
||||
return None
|
||||
return (
|
||||
tuple(tensor.shape),
|
||||
tuple(tensor.stride()),
|
||||
str(tensor.dtype),
|
||||
str(tensor.device),
|
||||
)
|
||||
|
||||
|
||||
def _cuda_graph_context_signature(context: Any) -> tuple[Any, ...]:
|
||||
sig = _cuda_graph_tensor_signature
|
||||
return (
|
||||
tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
|
||||
sig(context.cross_mask),
|
||||
sig(context.self_mask),
|
||||
sig(context.valid_action),
|
||||
None if context.rope_cache is None else tuple(sig(t) for t in context.rope_cache),
|
||||
)
|
||||
|
||||
|
||||
def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> tuple[Any, ...]:
|
||||
sig = _cuda_graph_tensor_signature
|
||||
return tuple(
|
||||
(
|
||||
sig(step.conditioning),
|
||||
tuple(tuple(sig(t) for t in block_modulation) for block_modulation in step.block_modulations),
|
||||
tuple(sig(t) for t in step.final_modulation),
|
||||
)
|
||||
for step in modulations
|
||||
)
|
||||
|
||||
|
||||
def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> tuple[Any, ...]:
|
||||
sig = _cuda_graph_tensor_signature
|
||||
return (
|
||||
sig(inputs.trajectory),
|
||||
_cuda_graph_context_signature(inputs.context),
|
||||
_cuda_graph_modulation_signature(inputs.modulations),
|
||||
sig(inputs.action_dim_is_pad),
|
||||
int(steps),
|
||||
)
|
||||
|
||||
|
||||
def _clone_static_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None:
|
||||
if tensor is None:
|
||||
return None
|
||||
static = torch.empty_strided(
|
||||
tuple(tensor.shape),
|
||||
tuple(tensor.stride()),
|
||||
device=tensor.device,
|
||||
dtype=tensor.dtype,
|
||||
)
|
||||
static.copy_(tensor)
|
||||
return static
|
||||
|
||||
|
||||
def _clone_static_context(context: Any) -> Any:
|
||||
rope_cache = None
|
||||
if context.rope_cache is not None:
|
||||
rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
|
||||
return context.__class__(
|
||||
kv_contexts=tuple((_clone_static_tensor(k), _clone_static_tensor(v)) for k, v in context.kv_contexts),
|
||||
cross_mask=_clone_static_tensor(context.cross_mask),
|
||||
self_mask=_clone_static_tensor(context.self_mask),
|
||||
valid_action=_clone_static_tensor(context.valid_action),
|
||||
rope_cache=rope_cache,
|
||||
)
|
||||
|
||||
|
||||
def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
|
||||
return tuple(
|
||||
step.__class__(
|
||||
conditioning=_clone_static_tensor(step.conditioning),
|
||||
block_modulations=tuple(
|
||||
tuple(_clone_static_tensor(t) for t in block_modulation)
|
||||
for block_modulation in step.block_modulations
|
||||
),
|
||||
final_modulation=tuple(_clone_static_tensor(t) for t in step.final_modulation),
|
||||
)
|
||||
for step in modulations
|
||||
)
|
||||
|
||||
|
||||
def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
|
||||
return _ActionFlowInputs(
|
||||
trajectory=_clone_static_tensor(inputs.trajectory),
|
||||
context=_clone_static_context(inputs.context),
|
||||
modulations=_clone_static_modulations(inputs.modulations),
|
||||
action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
|
||||
)
|
||||
|
||||
|
||||
def _copy_context_(dst: Any, src: Any) -> None:
|
||||
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
|
||||
dst_k.copy_(src_k)
|
||||
dst_v.copy_(src_v)
|
||||
if src.cross_mask is not None:
|
||||
dst.cross_mask.copy_(src.cross_mask)
|
||||
if src.self_mask is not None:
|
||||
dst.self_mask.copy_(src.self_mask)
|
||||
if src.valid_action is not None:
|
||||
dst.valid_action.copy_(src.valid_action)
|
||||
if src.rope_cache is not None:
|
||||
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
|
||||
dst_tensor.copy_(src_tensor)
|
||||
|
||||
|
||||
def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
|
||||
dst.trajectory.copy_(src.trajectory)
|
||||
_copy_context_(dst.context, src.context)
|
||||
if src.action_dim_is_pad is not None:
|
||||
dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
|
||||
|
||||
|
||||
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def _apply_rotary_pos_emb(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
unsqueeze_dim: int = 1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (_rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (_rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def _capture_cuda_graph(
|
||||
fn,
|
||||
device: torch.device,
|
||||
*,
|
||||
after_warmup=None,
|
||||
) -> tuple[torch.cuda.CUDAGraph, Any]:
|
||||
warmup_stream = torch.cuda.Stream(device=device)
|
||||
warmup_stream.wait_stream(torch.cuda.current_stream(device))
|
||||
with torch.cuda.stream(warmup_stream):
|
||||
fn()
|
||||
torch.cuda.current_stream(device).wait_stream(warmup_stream)
|
||||
if after_warmup is not None:
|
||||
after_warmup()
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
output = fn()
|
||||
return graph, output
|
||||
4591
src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py
Normal file
4591
src/lerobot/policies/molmoact2/hf_model/modeling_molmoact2.py
Normal file
File diff suppressed because it is too large
Load Diff
431
src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py
Normal file
431
src/lerobot/policies/molmoact2/hf_model/processing_molmoact2.py
Normal file
@@ -0,0 +1,431 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""
|
||||
Processor class for MolmoAct2.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
import dataclasses
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.video_utils import VideoInput
|
||||
from transformers.processing_utils import (
|
||||
Unpack,
|
||||
ProcessingKwargs,
|
||||
ProcessorMixin,
|
||||
)
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
|
||||
from transformers.utils import logging
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
|
||||
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
|
||||
IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
|
||||
IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
|
||||
IM_START_TOKEN = f"<im_start>"
|
||||
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
|
||||
FRAME_START_TOKEN = f"<frame_start>"
|
||||
IM_END_TOKEN = f"<im_end>"
|
||||
FRAME_END_TOKEN = f"<frame_end>"
|
||||
IM_COL_TOKEN = f"<im_col>"
|
||||
IMAGE_PROMPT = "<|image|>"
|
||||
VIDEO_PROMPT = "<|video|>"
|
||||
|
||||
IMAGE_TOKENS = [
|
||||
IMAGE_PATCH_TOKEN,
|
||||
IM_COL_TOKEN,
|
||||
IM_START_TOKEN,
|
||||
LOW_RES_IMAGE_START_TOKEN,
|
||||
FRAME_START_TOKEN,
|
||||
IM_END_TOKEN,
|
||||
FRAME_END_TOKEN,
|
||||
IMAGE_LOW_RES_TOKEN,
|
||||
]
|
||||
|
||||
|
||||
class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
|
||||
"""MolmoAct2 processor kwargs"""
|
||||
|
||||
images_kwargs: MolmoAct2ImagesKwargs
|
||||
videos_kwargs: MolmoAct2VideoProcessorKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
"return_mm_token_type_ids": True,
|
||||
},
|
||||
"videos_kwargs": {"return_metadata": True},
|
||||
}
|
||||
|
||||
|
||||
class MolmoAct2Processor(ProcessorMixin):
|
||||
attributes = ["image_processor", "video_processor", "tokenizer"]
|
||||
optional_attributes = [
|
||||
"chat_template",
|
||||
"time_mode",
|
||||
"image_use_col_tokens",
|
||||
"use_single_crop_col_tokens",
|
||||
"use_single_crop_start_token",
|
||||
"video_use_col_tokens",
|
||||
"use_frame_special_tokens",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
video_processor_class = "AutoVideoProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor: MolmoAct2ImageProcessor = None,
|
||||
video_processor: MolmoAct2VideoProcessor = None,
|
||||
tokenizer: AutoTokenizer = None,
|
||||
chat_template: str | None = None,
|
||||
image_use_col_tokens: bool | None = True,
|
||||
use_single_crop_col_tokens: bool | None = None,
|
||||
use_single_crop_start_token: bool | None = True,
|
||||
video_use_col_tokens: bool | None = False,
|
||||
use_frame_special_tokens: bool | None = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
image_processor,
|
||||
video_processor,
|
||||
tokenizer,
|
||||
chat_template=chat_template,
|
||||
)
|
||||
self.image_use_col_tokens = image_use_col_tokens
|
||||
self.use_single_crop_col_tokens = use_single_crop_col_tokens
|
||||
self.use_single_crop_start_token = use_single_crop_start_token
|
||||
self.video_use_col_tokens = video_use_col_tokens
|
||||
self.use_frame_special_tokens = use_frame_special_tokens
|
||||
|
||||
self.image_placeholder_token = IMAGE_PROMPT
|
||||
self.video_placeholder_token = VIDEO_PROMPT
|
||||
self.image_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in IMAGE_TOKENS]
|
||||
|
||||
def get_image_tokens(self, image_grid: np.ndarray):
|
||||
resized_h, resized_w, height, width = image_grid
|
||||
if int(height) == 0 or int(width) == 0:
|
||||
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
|
||||
use_single_crop_col_tokens = (
|
||||
self.image_use_col_tokens
|
||||
if self.use_single_crop_col_tokens is None
|
||||
else self.use_single_crop_col_tokens
|
||||
)
|
||||
if use_single_crop_col_tokens:
|
||||
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||
joint = [
|
||||
[IM_START_TOKEN],
|
||||
np.tile(per_row, [resized_h]),
|
||||
[IM_END_TOKEN],
|
||||
]
|
||||
return np.concatenate(joint)
|
||||
per_row = np.full(width, IMAGE_PATCH_TOKEN)
|
||||
if self.image_use_col_tokens:
|
||||
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||
joint = [
|
||||
[IM_START_TOKEN],
|
||||
np.tile(per_row, [height]),
|
||||
[IM_END_TOKEN],
|
||||
]
|
||||
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
|
||||
use_single_crop_col_tokens = (
|
||||
self.image_use_col_tokens
|
||||
if self.use_single_crop_col_tokens is None
|
||||
else self.use_single_crop_col_tokens
|
||||
)
|
||||
image_start_token = LOW_RES_IMAGE_START_TOKEN if self.use_single_crop_start_token else IM_START_TOKEN
|
||||
if use_single_crop_col_tokens:
|
||||
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||
joint = [
|
||||
[image_start_token],
|
||||
np.tile(per_row, [resized_h]),
|
||||
[IM_END_TOKEN],
|
||||
] + joint
|
||||
|
||||
return np.concatenate(joint)
|
||||
|
||||
def get_video_string(
|
||||
self,
|
||||
video_grid: np.ndarray,
|
||||
timestamps: np.ndarray,
|
||||
):
|
||||
if self.use_frame_special_tokens:
|
||||
start_token_id = FRAME_START_TOKEN
|
||||
end_token_id = FRAME_END_TOKEN
|
||||
else:
|
||||
start_token_id = IM_START_TOKEN
|
||||
end_token_id = IM_END_TOKEN
|
||||
|
||||
num_frames, h, w = video_grid
|
||||
video_string: str = ""
|
||||
for frame_idx, frame_time in enumerate(timestamps):
|
||||
# `per-frame-compact` time mode
|
||||
prev_space = " " if frame_idx > 0 else ""
|
||||
frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
|
||||
|
||||
video_string += frame_prefix
|
||||
per_row = np.full(w, IMAGE_PATCH_TOKEN)
|
||||
if self.video_use_col_tokens:
|
||||
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||
extra_tokens = np.tile(per_row, [h])
|
||||
video_tokens = [
|
||||
[start_token_id],
|
||||
extra_tokens,
|
||||
[end_token_id],
|
||||
]
|
||||
video_string += "".join(np.concatenate(video_tokens, 0))
|
||||
|
||||
return video_string
|
||||
|
||||
def insert_bos(
|
||||
self,
|
||||
input_ids: np.ndarray,
|
||||
attention_mask: np.ndarray,
|
||||
bos_token_id: int,
|
||||
pad_token_id: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_ids: [B, S] array with left padding
|
||||
attention_mask: [B, S] array (0 for pad, 1 for valid)
|
||||
bos_token_id: int
|
||||
pad_token_id: int
|
||||
Returns:
|
||||
input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
|
||||
attention_mask_out: same shape as input_ids_out
|
||||
"""
|
||||
|
||||
need_to_expand = len(input_ids.shape) == 1
|
||||
if need_to_expand:
|
||||
input_ids = input_ids[None, :]
|
||||
attention_mask = attention_mask[None, :]
|
||||
|
||||
B, S = input_ids.shape
|
||||
|
||||
# Handle zero-length sequence
|
||||
if S == 0:
|
||||
new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
|
||||
new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
|
||||
if need_to_expand:
|
||||
new_input_ids = new_input_ids[0]
|
||||
new_attention_mask = new_attention_mask[0]
|
||||
return new_input_ids, new_attention_mask
|
||||
|
||||
first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
|
||||
bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
|
||||
|
||||
if bos_already_present:
|
||||
if need_to_expand:
|
||||
input_ids = input_ids[0]
|
||||
attention_mask = attention_mask[0]
|
||||
return input_ids, attention_mask
|
||||
else:
|
||||
new_input_ids = np.full((B, S + 1), pad_token_id, dtype=input_ids.dtype)
|
||||
new_attention_mask = np.zeros((B, S + 1), dtype=attention_mask.dtype)
|
||||
|
||||
src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
|
||||
valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
|
||||
tgt_idx = src_idx + 1 # shit right
|
||||
batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
|
||||
|
||||
# flatten valid_positions
|
||||
flat_vals = input_ids[valid_mask]
|
||||
flat_batch = batch_idx[valid_mask]
|
||||
flat_tgt = tgt_idx[valid_mask]
|
||||
|
||||
new_input_ids[flat_batch, flat_tgt] = flat_vals
|
||||
new_attention_mask[flat_batch, flat_tgt] = 1
|
||||
|
||||
insert_pos = first_valid_index
|
||||
new_input_ids[np.arange(B), insert_pos] = bos_token_id
|
||||
new_attention_mask[np.arange(B), insert_pos] = 1
|
||||
|
||||
if need_to_expand:
|
||||
new_input_ids = new_input_ids[0]
|
||||
new_attention_mask = new_attention_mask[0]
|
||||
|
||||
return new_input_ids, new_attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
||||
images: ImageInput = None,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[MolmoAct2ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
|
||||
Args:
|
||||
text (`str`, `list[str]`, `list[list[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
videos (`dict[str, Any]` or `list[dict[str, Any]]`):
|
||||
The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
|
||||
- `"frames"`: `np.ndarray` of shape (T, H, W, 3)
|
||||
- `"timestamps"`: `np.ndarray` of shape (T,)
|
||||
- `"sampled_fps"`: `float` (optional)
|
||||
- `"sampling_augmentation"`: `str` (optional)
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
`BatchFeature`: A [`BatchFeature`] with the following fields:
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
|
||||
Returned when `images` is not `None`.
|
||||
- **image_grids** -- Grids of images. Returned when `images` is not `None`.
|
||||
- **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
|
||||
- **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
|
||||
Returned when `videos` is not `None`.
|
||||
- **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
|
||||
"""
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
MolmoAct2ProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
image_grids = image_inputs["image_grids"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
image_grids = None
|
||||
|
||||
if videos is not None:
|
||||
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
||||
video_grids = videos_inputs["video_grids"]
|
||||
# If user has not requested video metadata, pop it
|
||||
if "return_metadata" not in kwargs:
|
||||
video_metadata = videos_inputs.pop("video_metadata")
|
||||
else:
|
||||
video_metadata = videos_inputs["video_metadata"]
|
||||
else:
|
||||
videos_inputs = {}
|
||||
video_grids = None
|
||||
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
text = text.copy() # below lines change text in-place
|
||||
|
||||
if image_grids is not None:
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
num_images = text[i].count(self.image_placeholder_token)
|
||||
image_grids_i = image_grids[index : index + num_images]
|
||||
for image_grid in image_grids_i:
|
||||
image_tokens = self.get_image_tokens(image_grid)
|
||||
image_string = "".join(image_tokens)
|
||||
text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
|
||||
index += num_images
|
||||
|
||||
if video_grids is not None:
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
num_videos = text[i].count(self.video_placeholder_token)
|
||||
assert num_videos in {0, 1}, "At most one video is supported for now"
|
||||
video_grids_i = video_grids[index : index + num_videos]
|
||||
metadata_i = video_metadata[index : index + num_videos]
|
||||
for video_grid, metadata in zip(video_grids_i, metadata_i):
|
||||
video_string = self.get_video_string(
|
||||
video_grid,
|
||||
metadata.timestamps,
|
||||
)
|
||||
text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
|
||||
index += num_videos
|
||||
|
||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
|
||||
input_ids = text_inputs["input_ids"]
|
||||
attention_mask = text_inputs["attention_mask"]
|
||||
|
||||
input_ids = np.array(input_ids)
|
||||
attention_mask = np.array(attention_mask)
|
||||
|
||||
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
|
||||
input_ids, attention_mask = self.insert_bos(
|
||||
input_ids, attention_mask, bos, self.tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
if return_mm_token_type_ids:
|
||||
image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
|
||||
token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
|
||||
text_inputs["token_type_ids"] = token_type_ids.tolist()
|
||||
|
||||
text_inputs["input_ids"] = input_ids.tolist()
|
||||
text_inputs["attention_mask"] = attention_mask.tolist()
|
||||
|
||||
return BatchFeature(
|
||||
data={**text_inputs, **image_inputs, **videos_inputs},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
def post_process_image_text_to_text(
|
||||
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
|
||||
):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
|
||||
Args:
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
|
||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
|
||||
**kwargs:
|
||||
Additional arguments to be passed to the tokenizer's `batch_decode method`.
|
||||
|
||||
Returns:
|
||||
`list[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(
|
||||
generated_outputs,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
MolmoAct2Processor.register_for_auto_class()
|
||||
@@ -0,0 +1,997 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Video processor class for MolmoAct2"""
|
||||
|
||||
from functools import partial
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import redirect_stdout
|
||||
from io import BytesIO
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional, Union
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import einops
|
||||
import torch
|
||||
import torchvision.transforms
|
||||
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
validate_kwargs,
|
||||
)
|
||||
from transformers.video_utils import (
|
||||
VideoInput,
|
||||
is_valid_video,
|
||||
make_batched_videos,
|
||||
make_batched_metadata,
|
||||
VideoMetadata,
|
||||
)
|
||||
from transformers.processing_utils import Unpack, VideosKwargs
|
||||
from transformers.video_processing_utils import BaseVideoProcessor
|
||||
from transformers.utils import logging
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.utils import (
|
||||
is_av_available,
|
||||
is_decord_available,
|
||||
is_torchcodec_available,
|
||||
is_yt_dlp_available,
|
||||
TensorType,
|
||||
logging,
|
||||
to_numpy,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MAX_VIDEO_FPS = 8
|
||||
|
||||
|
||||
def normalize_image(
|
||||
image: np.ndarray,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
) -> np.ndarray:
|
||||
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
|
||||
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
|
||||
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
|
||||
image /= np.array(image_std, dtype=np.float32)[None, None, :]
|
||||
return image
|
||||
|
||||
|
||||
def resize_image(
|
||||
image: np.ndarray,
|
||||
desired_output_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
) -> np.ndarray:
|
||||
if len(image.shape) == 3:
|
||||
is_video = False
|
||||
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
||||
else:
|
||||
is_video = True
|
||||
image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2])
|
||||
dtype = image.dtype
|
||||
if torch.is_floating_point(image):
|
||||
in_min = 0.0
|
||||
in_max = 1.0
|
||||
resized = torchvision.transforms.Resize(
|
||||
desired_output_size,
|
||||
resample,
|
||||
antialias=False,
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
|
||||
else:
|
||||
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
|
||||
image.dtype
|
||||
)
|
||||
in_min = 0.0
|
||||
in_max = 255.0
|
||||
resized = torchvision.transforms.Resize(
|
||||
desired_output_size,
|
||||
resample,
|
||||
antialias=False,
|
||||
)(image)
|
||||
resized = torch.clip(resized, 0, 255).to(dtype)
|
||||
|
||||
resized = resized.to(torch.float32)
|
||||
resized = (resized - in_min) / (in_max - in_min)
|
||||
|
||||
if is_video:
|
||||
resized = torch.permute(resized, [0, 2, 3, 1]).numpy()
|
||||
else:
|
||||
resized = torch.permute(resized, [1, 2, 0]).numpy()
|
||||
|
||||
return resized
|
||||
|
||||
|
||||
def build_resized_image(
|
||||
image: np.ndarray,
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
resized = resize_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
)
|
||||
resized = normalize_image(resized, image_mean, image_std)
|
||||
if len(resized.shape) == 3:
|
||||
resized = np.expand_dims(resized, 0)
|
||||
crop_patch_w = base_image_input_size[1] // image_patch_size
|
||||
crop_patch_h = base_image_input_size[0] // image_patch_size
|
||||
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
|
||||
return resized, resize_idx
|
||||
|
||||
|
||||
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
|
||||
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
|
||||
if len(array.shape) == 3:
|
||||
n_crops, h, w = array.shape
|
||||
h_patches = h // patch_size
|
||||
w_patches = w // patch_size
|
||||
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
|
||||
array = np.transpose(array, [0, 1, 3, 2, 4])
|
||||
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
|
||||
return array
|
||||
else:
|
||||
n_crops, h, w, c = array.shape
|
||||
h_patches = h // patch_size
|
||||
w_patches = w // patch_size
|
||||
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
|
||||
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
|
||||
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
|
||||
return array
|
||||
|
||||
|
||||
def arange_for_pooling(
|
||||
idx_arr: np.ndarray,
|
||||
pool_h: int,
|
||||
pool_w: int,
|
||||
) -> np.ndarray:
|
||||
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
|
||||
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
|
||||
idx_arr = np.pad(
|
||||
idx_arr,
|
||||
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
|
||||
mode="constant",
|
||||
constant_values=-1,
|
||||
)
|
||||
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
|
||||
|
||||
|
||||
def image_to_patches_and_grids(
|
||||
image: ImageInput,
|
||||
base_image_input_size: list[int],
|
||||
resample: PILImageResampling,
|
||||
image_mean: list[float],
|
||||
image_std: list[float],
|
||||
image_patch_size: int,
|
||||
image_pooling_w: int,
|
||||
image_pooling_h: int,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
:return image_grids, the shape of each image after pooling
|
||||
:return crops, the image crops to processes with the ViT
|
||||
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
|
||||
patches in `crops` to pool for that token, masked with -1
|
||||
"""
|
||||
if isinstance(base_image_input_size, int):
|
||||
base_image_input_size = (base_image_input_size, base_image_input_size)
|
||||
|
||||
pooling_w = image_pooling_w
|
||||
pooling_h = image_pooling_h
|
||||
|
||||
resized, resize_idx = build_resized_image(
|
||||
image,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
image_patch_size,
|
||||
)
|
||||
pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
||||
h, w = pooling_idx.shape[:2]
|
||||
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
|
||||
image_grid = [h, w]
|
||||
return (
|
||||
image_grid,
|
||||
batch_pixels_to_patches(resized, image_patch_size),
|
||||
pooling_idx,
|
||||
)
|
||||
|
||||
|
||||
def get_candidate_target_fps(
|
||||
video_fps: int | float,
|
||||
sampling_fps: int | float,
|
||||
max_fps: int | float = MAX_VIDEO_FPS,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Return the subset of `video_fps` factors that remain multiples of `sampling_fps`.
|
||||
|
||||
Examples:
|
||||
>>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
|
||||
[2, 6]
|
||||
>>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
|
||||
[1, 5]
|
||||
>>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
|
||||
[2]
|
||||
>>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: sampling_fps=2 must divide video_fps=5 to produce consistent frame steps.
|
||||
"""
|
||||
video_fps = int(video_fps)
|
||||
sampling_fps = int(sampling_fps)
|
||||
max_fps = int(max_fps)
|
||||
|
||||
if sampling_fps is None:
|
||||
raise ValueError("sampling_fps must be provided")
|
||||
if video_fps <= 0 or sampling_fps <= 0:
|
||||
raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
|
||||
if video_fps % sampling_fps != 0:
|
||||
raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")
|
||||
|
||||
candidates = []
|
||||
for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
|
||||
if candidate > max_fps:
|
||||
break
|
||||
if video_fps % candidate == 0:
|
||||
candidates.append(float(candidate))
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
def read_video_decord(
|
||||
video_path,
|
||||
sample_timestamps_fn: Callable,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Decode a video using the Decord backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_timestamps_fn (`Callable`):
|
||||
A callable function that will return timestamps at which the video should be sampled.
|
||||
|
||||
Returns:
|
||||
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import from decord
|
||||
import importlib
|
||||
|
||||
decord = importlib.import_module("decord")
|
||||
|
||||
vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu
|
||||
video_fps = vr.get_avg_fps()
|
||||
total_num_frames = len(vr)
|
||||
time_stamps = vr.get_frame_timestamp(list(range(len(vr))))
|
||||
duration = time_stamps[-1][1] - time_stamps[0][0]
|
||||
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames),
|
||||
fps=float(video_fps),
|
||||
duration=float(duration),
|
||||
video_backend="decord",
|
||||
)
|
||||
|
||||
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
||||
target_timestamps = np.array(target_timestamps)
|
||||
offset = time_stamps[0, 0]
|
||||
|
||||
ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side="right")
|
||||
ix = np.minimum(ix, len(time_stamps) - 1)
|
||||
|
||||
video = vr.get_batch(ix).asnumpy()
|
||||
metadata.update(
|
||||
{
|
||||
"frames_indices": target_timestamps * video_fps,
|
||||
"height": video.shape[1],
|
||||
"width": video.shape[2],
|
||||
}
|
||||
)
|
||||
return video, metadata
|
||||
|
||||
|
||||
def read_video_torchcodec(
|
||||
video_path,
|
||||
sample_timestamps_fn: Callable,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Decode a video using torchcodec decoder.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_timestamps_fn (`Callable`):
|
||||
A callable function that will return timestamps at which the video should be sampled.
|
||||
|
||||
Returns:
|
||||
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import torchcodec
|
||||
import importlib
|
||||
|
||||
torchcodec = importlib.import_module("torchcodec")
|
||||
|
||||
decoder = torchcodec.decoders.VideoDecoder(
|
||||
video_path,
|
||||
# Interestingly `exact` mode takes less than approximate when we load the whole video
|
||||
seek_mode="exact",
|
||||
# Allow FFmpeg decide on the number of threads for efficiency
|
||||
num_ffmpeg_threads=0,
|
||||
)
|
||||
# If the first frame starts at > 0, we effectively clip the video starting at that time
|
||||
# since (most) video players would also skip to that time
|
||||
time_offset = decoder.metadata.begin_stream_seconds_from_content
|
||||
# Note this duration does assume we started playing at `time_offset`
|
||||
duration = decoder.metadata.duration_seconds
|
||||
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=decoder.metadata.num_frames,
|
||||
fps=decoder.metadata.average_fps,
|
||||
duration=duration,
|
||||
video_backend="torchcodec",
|
||||
height=decoder.metadata.height,
|
||||
width=decoder.metadata.width,
|
||||
)
|
||||
|
||||
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
||||
|
||||
# Floating point/rounding issues might cause `target_timestamps` to be very slightly
|
||||
# out-of-bounds, to handle this we sanity check then clip them
|
||||
assert all(x >= 0 for x in target_timestamps)
|
||||
assert all(x < duration + 1e-6 for x in target_timestamps)
|
||||
# 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
|
||||
# exact boundary value, we should still get the first/last frame anyway
|
||||
max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
|
||||
min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6
|
||||
# Note we avoid using numpy ops here to reduce floating precision issues
|
||||
timestamps = [x + time_offset for x in target_timestamps]
|
||||
timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]
|
||||
|
||||
video = (
|
||||
decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1)
|
||||
) # Convert to THWC format
|
||||
target_timestamps = np.array(target_timestamps)
|
||||
metadata.frames_indices = target_timestamps * metadata.fps
|
||||
|
||||
return video, metadata
|
||||
|
||||
|
||||
def read_video_pyav(
|
||||
video_path,
|
||||
sample_timestamps_fn: Callable,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Decode a video using the PyAV backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_timestamps_fn (`Callable`):
|
||||
A callable function that will return timestamps at which the video should be sampled.
|
||||
|
||||
Returns:
|
||||
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import torchcodec
|
||||
import importlib
|
||||
|
||||
av = importlib.import_module("av")
|
||||
|
||||
with av.open(video_path) as container:
|
||||
video_stream = container.streams.video[0]
|
||||
fps = video_stream.average_rate or video_stream.guessed_rate
|
||||
it = container.decode(video=0)
|
||||
frames = list(it)
|
||||
|
||||
stream = container.streams.video[0]
|
||||
start = frames[0].pts * stream.time_base
|
||||
container_end = stream.duration
|
||||
if container_end is not None:
|
||||
container_end *= stream.time_base
|
||||
if container_end is None or container_end < frames[-1].pts:
|
||||
# Some problem with stream duration, so use the frame PTS directly
|
||||
# and guess the duration of the last frame
|
||||
end = frames[-1].pts * stream.time_base + 1 / fps
|
||||
else:
|
||||
end = container_end
|
||||
duration = float(end - start)
|
||||
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=len(frames),
|
||||
fps=float(fps),
|
||||
duration=float(duration),
|
||||
video_backend="pyav",
|
||||
height=video_stream.height,
|
||||
width=video_stream.width,
|
||||
)
|
||||
|
||||
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
|
||||
offset = float(start)
|
||||
|
||||
target_timestamps = np.array(target_timestamps)
|
||||
end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
|
||||
indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side="right")
|
||||
indices = np.minimum(indices, len(end_time_stamps) - 1)
|
||||
|
||||
video = np.stack(
|
||||
[frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
|
||||
axis=0,
|
||||
)
|
||||
|
||||
metadata.frames_indices = target_timestamps * fps
|
||||
|
||||
return video, metadata
|
||||
|
||||
|
||||
VIDEO_DECODERS = {
|
||||
"decord": read_video_decord,
|
||||
"torchcodec": read_video_torchcodec,
|
||||
"pyav": read_video_pyav,
|
||||
}
|
||||
|
||||
|
||||
def load_video(
|
||||
video: VideoInput,
|
||||
backend: str = "decord",
|
||||
sample_timestamps_fn: Callable | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Loads `video` to a numpy array.
|
||||
|
||||
Args:
|
||||
video (`VideoInput`):
|
||||
The video to convert to the numpy array format. Can be a link to video or local path.
|
||||
backend (`str`, *optional*, defaults to `"decord"`):
|
||||
The backend to use when loading the video. Can be any of ["decord", "pyav", ""torchcodec"]. Defaults to "decord".
|
||||
sample_timestamps_fn (`Callable`):
|
||||
A callable function that will return timestamps at which the video should be sampled.
|
||||
"""
|
||||
|
||||
# Early exit if provided an array or `PIL` frames
|
||||
if not isinstance(video, str):
|
||||
metadata = [None] * len(video)
|
||||
return video, metadata
|
||||
|
||||
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
|
||||
if not is_yt_dlp_available():
|
||||
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
|
||||
# Lazy import from yt_dlp
|
||||
import importlib
|
||||
|
||||
yt_dlp = importlib.import_module("yt_dlp")
|
||||
|
||||
buffer = BytesIO()
|
||||
with redirect_stdout(buffer), yt_dlp.YoutubeDL() as f:
|
||||
f.download([video])
|
||||
bytes_obj = buffer.getvalue()
|
||||
file_obj = BytesIO(bytes_obj)
|
||||
elif video.startswith("http://") or video.startswith("https://"):
|
||||
file_obj = BytesIO(requests.get(video, timeout=10).content)
|
||||
elif os.path.isfile(video):
|
||||
file_obj = video
|
||||
else:
|
||||
raise TypeError(
|
||||
"Incorrect format used for video. Should be an url linking to an video or a local path."
|
||||
)
|
||||
|
||||
# can also load with decord, but not cv2/torchvision
|
||||
# both will fail in case of url links
|
||||
video_is_url = video.startswith("http://") or video.startswith("https://")
|
||||
if video_is_url and backend == "opencv":
|
||||
raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
|
||||
|
||||
if (
|
||||
(not is_decord_available() and backend == "decord")
|
||||
or (not is_torchcodec_available() and backend == "torchcodec")
|
||||
or (not is_av_available() and backend == "pyav")
|
||||
):
|
||||
raise ImportError(
|
||||
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
|
||||
f"Make sure to install {backend} before loading the video."
|
||||
)
|
||||
|
||||
video_decoder = VIDEO_DECODERS[backend]
|
||||
video, metadata = video_decoder(file_obj, sample_timestamps_fn, **kwargs)
|
||||
return video, metadata
|
||||
|
||||
|
||||
def get_target_fps(
|
||||
video_fps: float,
|
||||
max_frames: int,
|
||||
total_frames: int,
|
||||
frame_sample_mode: str,
|
||||
candidate_target_fps: tuple[float],
|
||||
) -> float:
|
||||
"""
|
||||
Get the target fps that best spans the video and has the most frames sampled
|
||||
"""
|
||||
num_frames_sampled = 0
|
||||
selected_target_fps = None
|
||||
for target_fps in candidate_target_fps:
|
||||
step_size = max(int(video_fps / target_fps), 1)
|
||||
num_frames_sampled_at_fps = int(total_frames / step_size)
|
||||
if num_frames_sampled == 0:
|
||||
if "uniform" in frame_sample_mode:
|
||||
if num_frames_sampled_at_fps > max_frames:
|
||||
break
|
||||
selected_target_fps = target_fps
|
||||
num_frames_sampled = num_frames_sampled_at_fps
|
||||
|
||||
else:
|
||||
# the candidate sampling fps increases so frame count can't decrease
|
||||
assert num_frames_sampled <= num_frames_sampled_at_fps
|
||||
if num_frames_sampled_at_fps > max_frames:
|
||||
# choose the sampling fps that spans the video
|
||||
continue
|
||||
|
||||
elif num_frames_sampled_at_fps > num_frames_sampled:
|
||||
# both are less than max_frames, choose the one with higher density of frames sampled
|
||||
selected_target_fps = target_fps
|
||||
num_frames_sampled = num_frames_sampled_at_fps
|
||||
return selected_target_fps
|
||||
|
||||
|
||||
def get_frame_times_and_chosen_fps(selected_target_fps, total_frames, max_frames, video_fps):
|
||||
if selected_target_fps is None:
|
||||
frame_indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
|
||||
else:
|
||||
step_size = max(int(video_fps / selected_target_fps), 1)
|
||||
frame_indices = np.arange(0, total_frames, step_size)
|
||||
if len(frame_indices) > max_frames:
|
||||
frame_indices = frame_indices[:max_frames]
|
||||
return selected_target_fps, frame_indices
|
||||
|
||||
|
||||
class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
|
||||
patch_size: int | None
|
||||
pooling_size: list[int] | None
|
||||
frame_sample_mode: str | None
|
||||
max_fps: int | None
|
||||
sampling_fps: int | None
|
||||
|
||||
|
||||
class MolmoAct2VideoProcessor(BaseVideoProcessor):
|
||||
resample = PILImageResampling.BILINEAR
|
||||
size = {"height": 378, "width": 378}
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
patch_size = 14
|
||||
pooling_size = [3, 3]
|
||||
do_sample_frames = True
|
||||
frame_sample_mode = "uniform_last_frame"
|
||||
max_fps = 2
|
||||
sampling_fps = 2
|
||||
valid_kwargs = MolmoAct2VideoProcessorKwargs
|
||||
model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[MolmoAct2VideoProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
if self.size is not None and (
|
||||
self.size.get("height", None) is None or self.size.get("width", None) is None
|
||||
):
|
||||
raise ValueError("size must contain 'height' and 'width' keys.")
|
||||
|
||||
def _further_process_kwargs(
|
||||
self,
|
||||
size: SizeDict | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Update kwargs that need further processing before being validated
|
||||
Can be overridden by subclasses to customize the processing of kwargs.
|
||||
"""
|
||||
if size is not None and ("height" not in size or "width" not in size):
|
||||
raise ValueError("size must contain 'height' and 'width' keys.")
|
||||
|
||||
return super()._further_process_kwargs(size=size, **kwargs)
|
||||
|
||||
def sample_times(
|
||||
self,
|
||||
metadata: VideoMetadata,
|
||||
frame_sample_mode: str,
|
||||
num_frames: int,
|
||||
max_fps: int | None = None,
|
||||
sampling_fps: int | None = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Time-based sampling if an array video is passed
|
||||
Args:
|
||||
metadata (`VideoMetadata`):
|
||||
Metadata of the video containing information about total duration, fps and total number of frames.
|
||||
frame_sample_mode (`str`, *optional*):
|
||||
Mode to sample frames. Defaults to `self.frame_sample_mode`.
|
||||
num_frames (`int`, *optional*):
|
||||
Maximum number of frames to sample. Defaults to `self.num_frames`.
|
||||
man_fps (`int`, *optional*):
|
||||
Maximum frames per second to sample.
|
||||
sampling_fps (`int`, *optional*):
|
||||
Sampling frames per second. Defaults to `self.sampling_fps`.
|
||||
Used when `frame_sample_mode` is `"fps"`.
|
||||
"""
|
||||
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
|
||||
num_frames = num_frames or self.num_frames
|
||||
sampling_fps = sampling_fps or self.sampling_fps
|
||||
|
||||
duration = metadata.duration or metadata.total_num_frames / metadata.fps
|
||||
if frame_sample_mode == "fps":
|
||||
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
|
||||
# Try larger and larger FPSs until we hit one that can't span the video
|
||||
target_fps = candidate_target_fps[0]
|
||||
for candidate_fps in candidate_target_fps[1:]:
|
||||
if num_frames / candidate_fps < duration:
|
||||
break
|
||||
target_fps = candidate_fps
|
||||
times = np.arange(0, num_frames) / target_fps
|
||||
times = times[times < duration]
|
||||
return times
|
||||
elif frame_sample_mode == "uniform_last_frame":
|
||||
if max_fps is not None:
|
||||
max_duration = (num_frames - 1) / max_fps # -1 to include the last frame
|
||||
if max_duration < duration:
|
||||
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
|
||||
else:
|
||||
times = np.arange(0.0, stop=duration, step=1 / max_fps)
|
||||
times = np.concatenate([times, [duration]], axis=0)
|
||||
assert len(times) <= num_frames
|
||||
else:
|
||||
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
|
||||
return times
|
||||
else:
|
||||
raise NotImplementedError(frame_sample_mode)
|
||||
|
||||
def sample_frames(
|
||||
self,
|
||||
metadata: VideoMetadata,
|
||||
frame_sample_mode: str | None = None,
|
||||
num_frames: int | None = None,
|
||||
max_fps: int | None = None,
|
||||
sampling_fps: int | None = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Frame-based sampling if an array video is passed
|
||||
Args:
|
||||
metadata (`VideoMetadata`):
|
||||
Metadata of the video containing information about total duration, fps and total number of frames.
|
||||
frame_sample_mode (`str`, *optional*):
|
||||
Mode to sample frames. Defaults to `self.frame_sample_mode`.
|
||||
num_frames (`int`, *optional*):
|
||||
Maximum number of frames to sample. Defaults to `self.num_frames`.
|
||||
max_fps (`int`, *optional*):
|
||||
Maximum frames per second to sample.
|
||||
sampling_fps (`int`, *optional*):
|
||||
Sampling frames per second. Defaults to `self.sampling_fps`.
|
||||
Used when `frame_sample_mode` is `"fps"`.
|
||||
"""
|
||||
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
|
||||
num_frames = num_frames or self.num_frames
|
||||
sampling_fps = sampling_fps or self.sampling_fps
|
||||
|
||||
total_num_frames = metadata.total_num_frames
|
||||
if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
|
||||
duration = total_num_frames / metadata.fps
|
||||
if total_num_frames <= 2:
|
||||
return np.arange(total_num_frames).astype(int)
|
||||
if duration > (num_frames - 1) / max_fps: # -1 to include the last frame
|
||||
# uniform fallback
|
||||
indices = np.linspace(
|
||||
0,
|
||||
total_num_frames - 1,
|
||||
num=min(num_frames, total_num_frames),
|
||||
endpoint=True,
|
||||
).astype(int)
|
||||
return indices
|
||||
else:
|
||||
float_indices = np.arange(
|
||||
0.0,
|
||||
stop=total_num_frames - 1,
|
||||
step=float(metadata.fps / max_fps),
|
||||
)
|
||||
if np.round(float_indices[-1]) != total_num_frames - 1:
|
||||
float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
|
||||
indices = np.round(float_indices).astype(int)
|
||||
assert indices[-1] < total_num_frames
|
||||
assert len(float_indices) <= num_frames
|
||||
return indices
|
||||
elif frame_sample_mode == "uniform_last_frame":
|
||||
indices = np.linspace(
|
||||
0,
|
||||
total_num_frames - 1,
|
||||
num=min(num_frames, total_num_frames),
|
||||
endpoint=True,
|
||||
).astype(int)
|
||||
return indices
|
||||
elif frame_sample_mode == "fps":
|
||||
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
|
||||
selected_target_fps = get_target_fps(
|
||||
metadata.fps,
|
||||
num_frames,
|
||||
total_num_frames,
|
||||
frame_sample_mode,
|
||||
candidate_target_fps,
|
||||
)
|
||||
_, indices = get_frame_times_and_chosen_fps(
|
||||
selected_target_fps,
|
||||
total_num_frames,
|
||||
num_frames,
|
||||
metadata.fps,
|
||||
)
|
||||
return indices
|
||||
else:
|
||||
raise NotImplementedError(frame_sample_mode)
|
||||
|
||||
def fetch_videos(self, video_url_or_urls: str | list[str] | list[list[str]], sample_timestamps_fn=None):
|
||||
"""
|
||||
Convert a single or a list of urls into the corresponding `np.array` objects.
|
||||
|
||||
If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
|
||||
returned.
|
||||
"""
|
||||
if (not is_decord_available()) and (not is_torchcodec_available()) and (not is_av_available()):
|
||||
raise ImportError(
|
||||
"MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
|
||||
)
|
||||
|
||||
if is_decord_available():
|
||||
backend = "decord"
|
||||
elif is_torchcodec_available():
|
||||
warnings.warn(
|
||||
"`decord` is not installed and cannot be used to decode the video by default. "
|
||||
"Falling back to `torchcodec`."
|
||||
)
|
||||
backend = "torchcodec"
|
||||
else:
|
||||
warnings.warn(
|
||||
"`decord` is not installed and cannot be used to decode the video by default. "
|
||||
"Falling back to `PyAV`."
|
||||
)
|
||||
backend = "pyav"
|
||||
|
||||
if isinstance(video_url_or_urls, list):
|
||||
return list(
|
||||
zip(
|
||||
*[
|
||||
self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn)
|
||||
for x in video_url_or_urls
|
||||
]
|
||||
)
|
||||
)
|
||||
else:
|
||||
return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)
|
||||
|
||||
def _decode_and_sample_videos(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
video_metadata: VideoMetadata | dict,
|
||||
do_sample_frames: bool | None = None,
|
||||
sample_indices_fn: Callable | None = None,
|
||||
sample_timestamps_fn: Callable | None = None,
|
||||
):
|
||||
"""
|
||||
Decode input videos and sample frames if needed.
|
||||
"""
|
||||
videos = make_batched_videos(videos)
|
||||
video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
|
||||
|
||||
# Framed-based sampling if an array video is passed
|
||||
# Otherwise, time-based sampling with decoding
|
||||
if is_valid_video(videos[0]) and do_sample_frames:
|
||||
assert video_metadata[0].fps is not None, "FPS must be provided for video input"
|
||||
sampled_videos = []
|
||||
sampled_metadata = []
|
||||
for video, metadata in zip(videos, video_metadata):
|
||||
indices = sample_indices_fn(metadata=metadata)
|
||||
metadata.frames_indices = indices
|
||||
sampled_videos.append(video[indices])
|
||||
sampled_metadata.append(metadata)
|
||||
videos = sampled_videos
|
||||
video_metadata = sampled_metadata
|
||||
elif not is_valid_video(videos[0]):
|
||||
if sample_indices_fn is None:
|
||||
logger.warning(
|
||||
"do_sample_frames is False, but video array is not provided: "
|
||||
"Will decode the video and sample frames using MolmoAct2's default sampling mode"
|
||||
)
|
||||
if isinstance(videos[0], list):
|
||||
raise ValueError("A list of images is not supported for video input!")
|
||||
else:
|
||||
videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)
|
||||
|
||||
return videos, video_metadata
|
||||
|
||||
def _prepare_input_videos(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
**kwargs,
|
||||
) -> list[np.ndarray]:
|
||||
processed_videos = [to_numpy(video) for video in videos]
|
||||
return processed_videos
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
**kwargs: Unpack[MolmoAct2VideoProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
validate_kwargs(
|
||||
captured_kwargs=kwargs.keys(),
|
||||
valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
|
||||
)
|
||||
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
do_sample_frames = kwargs.pop("do_sample_frames")
|
||||
video_metadata = kwargs.pop("video_metadata")
|
||||
|
||||
sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
|
||||
sample_timestamps_fn = partial(self.sample_times, **kwargs)
|
||||
videos, video_metadata = self._decode_and_sample_videos(
|
||||
videos,
|
||||
video_metadata=video_metadata,
|
||||
do_sample_frames=do_sample_frames,
|
||||
sample_indices_fn=sample_indices_fn,
|
||||
sample_timestamps_fn=sample_timestamps_fn,
|
||||
)
|
||||
videos = self._prepare_input_videos(videos=videos)
|
||||
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
|
||||
return_metadata = kwargs.pop("return_metadata")
|
||||
preprocessed_videos = self._preprocess(videos=videos, **kwargs)
|
||||
if return_metadata:
|
||||
preprocessed_videos["video_metadata"] = video_metadata
|
||||
return preprocessed_videos
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
videos: list[np.ndarray],
|
||||
size: SizeDict | None = None,
|
||||
resample: PILImageResampling | None = None,
|
||||
image_mean: float | list[float] | None = None,
|
||||
image_std: float | list[float] | None = None,
|
||||
do_convert_rgb: bool | None = None,
|
||||
patch_size: int | None = None,
|
||||
pooling_size: list[int] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess a video for the model.
|
||||
Args:
|
||||
videos (`VideoInput`):
|
||||
Video to preprocess.
|
||||
size (`SizeDict`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
||||
The spatial patch size of the vision encoder.
|
||||
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
|
||||
The pooling size of the vision adapter.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
|
||||
Returns:
|
||||
A `BatchFeature` containing the following keys:
|
||||
- `pixel_values_videos`: The preprocessed videos.
|
||||
- `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`.
|
||||
- `video_grids`: The video grids.
|
||||
"""
|
||||
if size.height is None or size.width is None:
|
||||
raise ValueError("size must contain 'height' and 'width' keys.")
|
||||
|
||||
base_image_input_size = [size.height, size.width]
|
||||
|
||||
resample = resample or self.resample
|
||||
image_mean = image_mean or self.image_mean
|
||||
image_std = image_std or self.image_std
|
||||
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
|
||||
|
||||
patch_size = patch_size or self.patch_size
|
||||
pooling_size = pooling_size or self.pooling_size
|
||||
|
||||
image_pooling_h, image_pooling_w = pooling_size
|
||||
|
||||
batch_grids = []
|
||||
batch_crops = []
|
||||
batch_pooled_patches_idx = []
|
||||
|
||||
for video in videos:
|
||||
all_crops = []
|
||||
pooled_patches_idx = []
|
||||
|
||||
for frame in video:
|
||||
image_grid, crops, pooled_idx = image_to_patches_and_grids(
|
||||
frame,
|
||||
base_image_input_size,
|
||||
resample,
|
||||
image_mean,
|
||||
image_std,
|
||||
patch_size,
|
||||
image_pooling_w,
|
||||
image_pooling_h,
|
||||
)
|
||||
offset = sum(np.prod(x.shape[:2]) for x in all_crops)
|
||||
pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
|
||||
pooled_patches_idx.append(pooled_idx_with_offset)
|
||||
all_crops.append(crops)
|
||||
|
||||
video_grid = np.array([len(video), image_grid[0], image_grid[1]])
|
||||
all_crops = np.concatenate(all_crops, 0)
|
||||
pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)
|
||||
|
||||
batch_grids.append(video_grid)
|
||||
batch_crops.append(all_crops)
|
||||
batch_pooled_patches_idx.append(pooled_patches_idx)
|
||||
|
||||
video_grids = np.stack(batch_grids, 0)
|
||||
pixel_values_videos = np.concatenate(batch_crops, 0)
|
||||
video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
||||
|
||||
data = dict(
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
video_token_pooling=video_token_pooling,
|
||||
video_grids=video_grids,
|
||||
)
|
||||
|
||||
return BatchFeature(data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
MolmoAct2VideoProcessor.register_for_auto_class()
|
||||
1551
src/lerobot/policies/molmoact2/modeling_molmoact2.py
Normal file
1551
src/lerobot/policies/molmoact2/modeling_molmoact2.py
Normal file
File diff suppressed because it is too large
Load Diff
1083
src/lerobot/policies/molmoact2/processor_molmoact2.py
Normal file
1083
src/lerobot/policies/molmoact2/processor_molmoact2.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -141,6 +142,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -227,16 +237,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -258,15 +265,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -274,13 +282,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -488,8 +496,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -499,36 +508,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -907,7 +919,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -224,16 +234,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -255,15 +262,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -271,13 +279,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -441,13 +449,13 @@ class PaliGemmaWithExpertModel(
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
features = image_outputs.pooler_output
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -485,8 +493,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -496,36 +505,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -662,8 +674,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process language tokens
|
||||
def lang_embed_func(tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
return lang_emb
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||
embs.append(lang_emb)
|
||||
@@ -881,7 +892,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
1
src/lerobot/policies/vla_jepa/README.md
Symbolic link
1
src/lerobot/policies/vla_jepa/README.md
Symbolic link
@@ -0,0 +1 @@
|
||||
/home/maxime/github/robots/lerobot/docs/source/policy_vla_jepa_README.md
|
||||
10
src/lerobot/policies/vla_jepa/__init__.py
Normal file
10
src/lerobot/policies/vla_jepa/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .modeling_vla_jepa import VLAJEPAPolicy
|
||||
from .processor_vla_jepa import VLAJEPANewLineProcessor, make_vla_jepa_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"VLAJEPAConfig",
|
||||
"VLAJEPAPolicy",
|
||||
"VLAJEPANewLineProcessor",
|
||||
"make_vla_jepa_pre_post_processors",
|
||||
]
|
||||
327
src/lerobot/policies/vla_jepa/action_head.py
Normal file
327
src/lerobot/policies/vla_jepa/action_head.py
Normal file
@@ -0,0 +1,327 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||
else:
|
||||
|
||||
class ModelMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
class ConfigMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
register_to_config = lambda f: f # noqa: E731
|
||||
Attention = FeedForward = TimestepEmbedding = Timesteps = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
def swish(x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
timesteps = timesteps.float()
|
||||
batch_size, seq_len = timesteps.shape
|
||||
half_dim = self.embedding_dim // 2
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
|
||||
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
|
||||
|
||||
|
||||
class ActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim: int, hidden_size: int):
|
||||
super().__init__()
|
||||
self.layer1 = nn.Linear(action_dim, hidden_size)
|
||||
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
|
||||
self.layer3 = nn.Linear(hidden_size, hidden_size)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = actions.shape
|
||||
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
|
||||
raise ValueError("timesteps must have shape [batch_size].")
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
|
||||
action_emb = self.layer1(actions)
|
||||
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||
return self.layer3(swish(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
require_package("diffusers", extra="vla_jepa")
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
|
||||
return self.timestep_embedder(projected)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
||||
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
|
||||
self.silu = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
||||
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
|
||||
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
is_cross_attention: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_cross_attention = is_cross_attention
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
out_bias=True,
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None,
|
||||
temb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.norm1(hidden_states, temb)
|
||||
attention_context = encoder_hidden_states if self.is_cross_attention else None
|
||||
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
|
||||
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DiT(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.timestep_encoder = TimestepEncoder(self.inner_dim)
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
|
||||
is_cross_attention=layer_idx % 2 == 0,
|
||||
)
|
||||
for layer_idx in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
temb = self.timestep_encoder(timestep)
|
||||
x = hidden_states
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
|
||||
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
|
||||
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return self.proj_out_2(x)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionModelPreset:
|
||||
hidden_size: int
|
||||
attention_head_dim: int
|
||||
num_attention_heads: int
|
||||
|
||||
|
||||
DIT_PRESETS = {
|
||||
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
|
||||
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
|
||||
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
|
||||
}
|
||||
|
||||
|
||||
class VLAJEPAActionHead(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
|
||||
super().__init__()
|
||||
preset = DIT_PRESETS[config.action_model_type]
|
||||
self.config = config
|
||||
num_heads = config.action_num_heads or preset.num_attention_heads
|
||||
head_dim = config.action_attention_head_dim or preset.attention_head_dim
|
||||
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
|
||||
|
||||
self.input_embedding_dim = inner_dim
|
||||
self.action_horizon = config.chunk_size
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
hidden_size = config.action_hidden_size
|
||||
self.model = DiT(
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=head_dim,
|
||||
output_dim=hidden_size,
|
||||
num_layers=config.action_num_layers,
|
||||
dropout=config.action_dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
|
||||
self.action_decoder = nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(hidden_size, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, config.action_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
self.state_encoder = (
|
||||
nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(config.state_dim, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, inner_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
if config.state_dim > 0
|
||||
else None
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
|
||||
self.position_embedding = nn.Embedding(
|
||||
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
|
||||
inner_dim,
|
||||
)
|
||||
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
|
||||
return (self.config.action_noise_s - sample) / self.config.action_noise_s
|
||||
|
||||
def _build_inputs(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None,
|
||||
timesteps: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
action_features = self.action_encoder(actions, timesteps)
|
||||
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
|
||||
action_features = action_features + self.position_embedding(pos_ids)[None]
|
||||
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
|
||||
seq = [future_tokens, action_features]
|
||||
if state is not None and self.state_encoder is not None:
|
||||
if state.ndim == 2:
|
||||
state = state.unsqueeze(1)
|
||||
seq.insert(0, self.state_encoder(state))
|
||||
return torch.cat(seq, dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
action_is_pad: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
noise = torch.randn_like(actions)
|
||||
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
|
||||
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
|
||||
velocity = actions - noise
|
||||
t_discretized = (t * self.config.action_num_timestep_buckets).long()
|
||||
|
||||
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=t_discretized,
|
||||
)
|
||||
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
|
||||
|
||||
if action_is_pad is None:
|
||||
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
|
||||
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
|
||||
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
|
||||
num_valid = valid_mask.sum() * loss.shape[-1]
|
||||
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = conditioning_tokens.shape[0]
|
||||
actions = torch.randn(
|
||||
batch_size,
|
||||
self.action_horizon,
|
||||
self.config.action_dim,
|
||||
dtype=conditioning_tokens.dtype,
|
||||
device=conditioning_tokens.device,
|
||||
)
|
||||
dt = 1.0 / max(self.num_inference_timesteps, 1)
|
||||
for step in range(self.num_inference_timesteps):
|
||||
t_cont = step / float(max(self.num_inference_timesteps, 1))
|
||||
t_value = int(t_cont * self.config.action_num_timestep_buckets)
|
||||
timesteps = torch.full(
|
||||
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
|
||||
)
|
||||
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=timesteps,
|
||||
)
|
||||
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
|
||||
actions = actions + dt * pred_velocity
|
||||
return actions
|
||||
138
src/lerobot/policies/vla_jepa/configuration_vla_jepa.py
Normal file
138
src/lerobot/policies/vla_jepa/configuration_vla_jepa.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("vla_jepa")
|
||||
@dataclass
|
||||
class VLAJEPAConfig(PreTrainedConfig):
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 7
|
||||
n_action_steps: int = 7
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
|
||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||
freeze_qwen: bool = False
|
||||
enable_world_model: bool = True
|
||||
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
|
||||
# different action or state dimensionality, the input/output projection layers must be
|
||||
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
|
||||
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
|
||||
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
|
||||
reinit_modules: list[str] | None = None
|
||||
|
||||
tokenizer_padding_side: str = "left"
|
||||
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
||||
special_action_token: str = "<|action_{}|>"
|
||||
embodied_action_token: str = "<|embodied_action|>"
|
||||
|
||||
action_dim: int = 7
|
||||
state_dim: int = 8
|
||||
|
||||
num_action_tokens_per_timestep: int = 8
|
||||
num_embodied_action_tokens_per_instruction: int = 32
|
||||
num_inference_timesteps: int = 4
|
||||
|
||||
action_hidden_size: int = 1024
|
||||
action_model_type: str = "DiT-B"
|
||||
action_num_layers: int = 16
|
||||
action_num_heads: int | None = None
|
||||
action_attention_head_dim: int | None = None
|
||||
action_dropout: float = 0.2
|
||||
action_num_timestep_buckets: int = 1000
|
||||
action_noise_beta_alpha: float = 1.5
|
||||
action_noise_beta_beta: float = 1.0
|
||||
action_noise_s: float = 0.999
|
||||
num_target_vision_tokens: int = 32
|
||||
action_max_seq_len: int = 1024
|
||||
|
||||
# total video frames loaded per sample
|
||||
num_video_frames: int = 8
|
||||
predictor_depth: int = 12
|
||||
predictor_num_heads: int = 8
|
||||
predictor_mlp_ratio: float = 4.0
|
||||
predictor_dropout: float = 0.0
|
||||
world_model_loss_weight: float = 0.1
|
||||
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
|
||||
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
|
||||
|
||||
resize_images_to: tuple[int, int] | None = None
|
||||
binarize_gripper_action: bool = True
|
||||
pre_snap_gripper_action: bool = True
|
||||
clip_normalized_actions: bool = True
|
||||
torch_dtype: str = "bfloat16"
|
||||
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_grad_clip_norm: float = 10.0
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.freeze_qwen and self.enable_world_model:
|
||||
# freezing qwen backbone makes world model training irrelevant since no grad flows
|
||||
self.enable_world_model = False
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
|
||||
if self.num_video_frames < 2 * self.jepa_tubelet_size:
|
||||
raise ValueError(
|
||||
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
|
||||
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features:
|
||||
raise ValueError("VLAJEPA requires at least one visual input feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("VLAJEPA requires an action output feature.")
|
||||
self.action_dim = self.action_feature.shape[0]
|
||||
if self.robot_state_feature is not None:
|
||||
self.state_dim = self.robot_state_feature.shape[0]
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
|
||||
# matches original repo's observation_indices=list(range(video_horizon))
|
||||
return list(range(self.num_video_frames))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
615
src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
Normal file
615
src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
Normal file
@@ -0,0 +1,615 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModel, AutoVideoProcessor
|
||||
else:
|
||||
AutoModel = None
|
||||
AutoVideoProcessor = None
|
||||
|
||||
from .action_head import VLAJEPAActionHead
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .qwen_interface import Qwen3VLInterface
|
||||
from .world_model import ActionConditionedVideoPredictor
|
||||
|
||||
# ============================================================================
|
||||
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAModel(nn.Module):
|
||||
"""
|
||||
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
|
||||
|
||||
Components:
|
||||
- Qwen3-VL: vision-language backbone for fused embeddings
|
||||
- DiT-B: flow-matching action head for future action prediction
|
||||
- V-JEPA: world model for video frame prediction
|
||||
|
||||
Input: List[dict] native format (same as original starVLA)
|
||||
- "image": List[PIL.Image] (multi-view images)
|
||||
- "video": np.ndarray [V, T, H, W, 3]
|
||||
- "lang": str (task instruction)
|
||||
- "action": np.ndarray [T, action_dim] (optional, training only)
|
||||
- "state": np.ndarray [1, state_dim] (optional)
|
||||
"""
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
require_package("transformers", extra="vla_jepa")
|
||||
self.config = config
|
||||
|
||||
# Vision-language backbone
|
||||
self.qwen = Qwen3VLInterface(config)
|
||||
|
||||
# Tokenizer expansion for special action tokens
|
||||
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
|
||||
self.qwen.expand_tokenizer()
|
||||
)
|
||||
|
||||
# Action head (flow-matching DiT)
|
||||
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
|
||||
|
||||
# JEPA world model components
|
||||
if config.enable_world_model:
|
||||
self.video_encoder = AutoModel.from_pretrained(
|
||||
config.jepa_encoder_name,
|
||||
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
|
||||
num_views = config.jepa_tubelet_size
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
image_size = getattr(self.video_encoder.config, "image_size", None)
|
||||
if image_size is None:
|
||||
first_image_shape = next(iter(config.image_features.values())).shape
|
||||
image_size = first_image_shape[-1]
|
||||
self.video_predictor = ActionConditionedVideoPredictor(
|
||||
num_frames=config.num_video_frames // tubelet_size,
|
||||
img_size=(image_size, image_size),
|
||||
patch_size=16,
|
||||
tubelet_size=1,
|
||||
embed_dim=self.video_encoder.config.hidden_size * num_views,
|
||||
action_embed_dim=self.qwen.model.config.hidden_size,
|
||||
predictor_embed_dim=self.video_encoder.config.hidden_size,
|
||||
depth=config.predictor_depth,
|
||||
num_heads=config.predictor_num_heads,
|
||||
mlp_ratio=config.predictor_mlp_ratio,
|
||||
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
|
||||
)
|
||||
else:
|
||||
self.video_encoder = None
|
||||
self.video_processor = None
|
||||
self.video_predictor = None
|
||||
|
||||
if config.freeze_qwen:
|
||||
self.qwen.requires_grad_(False)
|
||||
|
||||
# Build prompt placeholders.
|
||||
# Use the encoder's actual tubelet_size when available (world model enabled),
|
||||
# otherwise fall back to config.
|
||||
_tubelet_size = (
|
||||
self.video_encoder.config.tubelet_size
|
||||
if config.enable_world_model
|
||||
else self.config.jepa_tubelet_size
|
||||
)
|
||||
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
|
||||
self.replace_prompt = "".join(
|
||||
token * self.config.num_action_tokens_per_timestep
|
||||
for token in self.action_tokens[:num_action_prompt_steps]
|
||||
)
|
||||
self.embodied_replace_prompt = (
|
||||
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
|
||||
)
|
||||
|
||||
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return the last decoder hidden state before the final RMSNorm.
|
||||
|
||||
The model was trained with the output of the last transformer block BEFORE
|
||||
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
|
||||
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
|
||||
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
|
||||
the correct pre-RMSNorm state, matching the training-time representation.
|
||||
"""
|
||||
captured: list[torch.Tensor] = []
|
||||
|
||||
def _hook(module, input, output):
|
||||
h = output[0] if isinstance(output, tuple) else output
|
||||
captured.append(h)
|
||||
|
||||
last_layer = self.qwen.model.model.language_model.layers[-1]
|
||||
handle = last_layer.register_forward_hook(_hook)
|
||||
try:
|
||||
self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
finally:
|
||||
handle.remove()
|
||||
|
||||
return captured[0] # [B, seq_len, H]
|
||||
|
||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||
|
||||
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Native forward pass following original starVLA VLA_JEPA.forward.
|
||||
|
||||
Args:
|
||||
examples: List of per-sample dicts with keys:
|
||||
"image" : List[PIL.Image] — multi-view images
|
||||
"video" : np.ndarray [V, T, H, W, 3]
|
||||
"lang" : str — task instruction
|
||||
"action" : np.ndarray [T, action_dim] (optional)
|
||||
"state" : np.ndarray [1, state_dim] (optional)
|
||||
|
||||
Returns:
|
||||
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
|
||||
"""
|
||||
# Unpack native format (same pattern as original VLA_JEPA.py)
|
||||
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
|
||||
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
|
||||
instructions = [ex["lang"] for ex in examples] # List[str]
|
||||
has_action = "action" in examples[0] and examples[0]["action"] is not None
|
||||
actions = [ex["action"] for ex in examples] if has_action else None
|
||||
has_state = "state" in examples[0] and examples[0]["state"] is not None
|
||||
state = [ex["state"] for ex in examples] if has_state else None
|
||||
action_is_pad = (
|
||||
[ex["action_is_pad"] for ex in examples]
|
||||
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
|
||||
batch_videos = np.stack(batch_videos)
|
||||
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
|
||||
|
||||
# Adjust number of views for the world model:
|
||||
# - fewer views than expected: duplicate the first view to fill up
|
||||
# - more views than expected: keep only the first num_views_world_model views
|
||||
num_views_world_model = self.config.jepa_tubelet_size
|
||||
if batch_videos.shape[1] < num_views_world_model:
|
||||
num_missing_views = num_views_world_model - batch_videos.shape[1]
|
||||
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
|
||||
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
|
||||
elif batch_videos.shape[1] > num_views_world_model:
|
||||
batch_videos = batch_videos[:, :num_views_world_model]
|
||||
|
||||
# ---- Step 1: QwenVL encode (same as original) ----
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
# Locate embodied-action tokens (always needed for action head)
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
# Locate action tokens (only needed for world model predictor)
|
||||
if self.config.enable_world_model:
|
||||
action_mask = torch.isin(
|
||||
qwen_inputs["input_ids"],
|
||||
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
||||
)
|
||||
action_indices = action_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
|
||||
if self.config.enable_world_model:
|
||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
||||
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
# ---- Step 2+3: JEPA Encoder + Predictor ----
|
||||
device_wm = last_hidden.device
|
||||
if not self.config.enable_world_model:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
|
||||
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device) # [B*V, T, C, H, W]
|
||||
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
device_wm = video_embeddings.device
|
||||
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||
|
||||
if t_enc_total < 2:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
|
||||
# input_states: positions 0..T-2, gt_states: positions 1..T-1
|
||||
t_enc_ctx = t_enc_total - 1
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
|
||||
predicted_states = self.video_predictor(
|
||||
input_states.float(),
|
||||
action_tokens[:, :expected_actions].float(),
|
||||
)
|
||||
|
||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
|
||||
if not has_action:
|
||||
return {"wm_loss": wm_loss}
|
||||
|
||||
# ---- Step 4: Action Head ----
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
actions_tensor = torch.tensor(
|
||||
np.array(actions), device=last_hidden.device, dtype=torch.float32
|
||||
) # [B, T_full, action_dim]
|
||||
action_horizon = self.config.chunk_size
|
||||
actions_target = actions_tensor[:, -action_horizon:, :]
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.tensor(
|
||||
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
|
||||
) # [B, 1, state_dim]
|
||||
|
||||
repeated_diffusion_steps = self.config.repeated_diffusion_steps
|
||||
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
|
||||
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
|
||||
if state_tensor is not None:
|
||||
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
|
||||
|
||||
action_is_pad_rep = None
|
||||
if action_is_pad is not None:
|
||||
pad_tensor = torch.stack(
|
||||
[
|
||||
p.to(actions_target.device)
|
||||
if isinstance(p, Tensor)
|
||||
else torch.tensor(p, device=actions_target.device)
|
||||
for p in action_is_pad
|
||||
]
|
||||
) # [B, T_full]
|
||||
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
|
||||
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
|
||||
|
||||
action_loss = self.action_model(
|
||||
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
|
||||
)
|
||||
|
||||
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||
|
||||
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
batch_images: list[list[Image.Image]],
|
||||
instructions: list[str],
|
||||
state: np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Native action prediction following original VLA_JEPA.predict_action.
|
||||
|
||||
Args:
|
||||
batch_images: List of samples; each is List[PIL.Image] (multi-view).
|
||||
instructions: Task instructions, one per sample.
|
||||
state: Optional [B, state_dim] numpy array.
|
||||
|
||||
Returns:
|
||||
np.ndarray [B, action_horizon, action_dim] — predicted actions.
|
||||
"""
|
||||
if self.config.resize_images_to is not None:
|
||||
height, width = self.config.resize_images_to
|
||||
resampling = getattr(Image, "Resampling", Image).BOX
|
||||
batch_images = [
|
||||
[image.resize((width, height), resample=resampling) for image in sample_images]
|
||||
for sample_images in batch_images
|
||||
]
|
||||
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.from_numpy(np.array(state)).to(
|
||||
device=last_hidden.device, dtype=last_hidden.dtype
|
||||
)
|
||||
|
||||
pred_actions = self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
|
||||
) # [B, action_horizon, action_dim]
|
||||
|
||||
return pred_actions.detach().cpu().numpy()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
LeRobot adapter for VLA-JEPA.
|
||||
|
||||
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
|
||||
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
|
||||
back to LeRobot format.
|
||||
"""
|
||||
|
||||
config_class = VLAJEPAConfig
|
||||
name = "vla_jepa"
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
if dataset_meta := kwargs.get("dataset_meta"):
|
||||
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
|
||||
# compatibility), so validate_features() may have read stale dims from a pretrained
|
||||
# config. Override state_dim/action_dim from the actual dataset being used.
|
||||
ds_features = dataset_meta.features
|
||||
if OBS_STATE in ds_features:
|
||||
config.state_dim = ds_features[OBS_STATE]["shape"][0]
|
||||
if ACTION in ds_features:
|
||||
config.action_dim = ds_features[ACTION]["shape"][0]
|
||||
|
||||
self.model = VLAJEPAModel(config)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
|
||||
|
||||
# ---- Format Conversion: LeRobot → Native ----
|
||||
|
||||
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
|
||||
"""
|
||||
Convert LeRobot batch format to native VLA-JEPA examples format.
|
||||
|
||||
LeRobot format:
|
||||
batch = {
|
||||
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
|
||||
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
|
||||
"action": Tensor [B, chunk_size, action_dim], (training only)
|
||||
"task": str | List[str], (optional instruction)
|
||||
}
|
||||
|
||||
Native format (List[dict]):
|
||||
{
|
||||
"image": List[PIL.Image], # multi-view images per sample
|
||||
"video": np.ndarray [V, T, H, W, 3],
|
||||
"lang": str, # task instruction
|
||||
"action": np.ndarray [T, action_dim], # optional
|
||||
"state": np.ndarray [1, state_dim], # optional
|
||||
}
|
||||
"""
|
||||
# Determine batch size from the first image feature
|
||||
image_keys = list(self.config.image_features.keys())
|
||||
if not image_keys:
|
||||
raise ValueError("VLAJEPA requires at least one image feature.")
|
||||
first_key = image_keys[0]
|
||||
first_tensor = batch[first_key]
|
||||
batch_size = first_tensor.shape[0]
|
||||
|
||||
# ---- Collect images per sample ----
|
||||
# images_per_sample[b][v] = PIL.Image for view v
|
||||
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
|
||||
for key in image_keys:
|
||||
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
|
||||
if tensor.ndim == 5:
|
||||
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
|
||||
# index 0 is the current observation (delta=0)
|
||||
tensor = tensor[:, 0]
|
||||
for b in range(batch_size):
|
||||
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
|
||||
|
||||
# ---- Collect videos per sample ----
|
||||
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
|
||||
# Check whether any image feature has a time dimension
|
||||
video_source = None
|
||||
for k in image_keys:
|
||||
if k in batch:
|
||||
video_source = batch[k] # Use first available for shape inspection
|
||||
break
|
||||
|
||||
if video_source is None:
|
||||
raise ValueError("No image data found in batch for video construction.")
|
||||
|
||||
videos_per_sample = []
|
||||
for b in range(batch_size):
|
||||
sample_views = []
|
||||
for k in image_keys:
|
||||
t = batch[k][b] # [C, H, W] or [T, C, H, W]
|
||||
if t.ndim == 3:
|
||||
t = t.unsqueeze(0) # [1, C, H, W]
|
||||
# Convert to [T, H, W, 3] numpy
|
||||
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
|
||||
# Clamp to [0, 255]
|
||||
if t_np.max() <= 1.0:
|
||||
t_np = t_np * 255.0
|
||||
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
|
||||
sample_views.append(t_np)
|
||||
# Stack views: [V, T, H, W, 3]
|
||||
videos_per_sample.append(np.stack(sample_views, axis=0))
|
||||
|
||||
# ---- Collect instructions ----
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
instructions = ["Execute the robot action."] * batch_size
|
||||
elif isinstance(tasks, str):
|
||||
instructions = [tasks] * batch_size
|
||||
else:
|
||||
instructions = list(tasks)
|
||||
|
||||
# ---- Collect actions (training only) ----
|
||||
actions_list = None
|
||||
action_is_pad_list = None
|
||||
actions_tensor = batch.get(ACTION)
|
||||
if actions_tensor is not None:
|
||||
if actions_tensor.ndim == 2:
|
||||
actions_tensor = actions_tensor.unsqueeze(1)
|
||||
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
action_is_pad_tensor = batch.get("action_is_pad")
|
||||
if action_is_pad_tensor is not None:
|
||||
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
|
||||
|
||||
# ---- Collect state ----
|
||||
state_list = None
|
||||
state_tensor = batch.get(OBS_STATE)
|
||||
if state_tensor is not None:
|
||||
if state_tensor.ndim > 2:
|
||||
state_tensor = state_tensor[:, -1, :]
|
||||
if state_tensor.ndim == 2:
|
||||
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
|
||||
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
|
||||
# ---- Assemble native examples ----
|
||||
examples = []
|
||||
for b in range(batch_size):
|
||||
example = {
|
||||
"image": images_per_sample[b],
|
||||
"video": videos_per_sample[b],
|
||||
"lang": instructions[b],
|
||||
}
|
||||
if actions_list is not None:
|
||||
example["action"] = actions_list[b]
|
||||
if action_is_pad_list is not None:
|
||||
example["action_is_pad"] = action_is_pad_list[b]
|
||||
if state_list is not None:
|
||||
example["state"] = state_list[b]
|
||||
examples.append(example)
|
||||
|
||||
return examples
|
||||
|
||||
# ---- LeRobot Policy Interface ----
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""LeRobot train forward: convert → native forward → aggregate losses."""
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
native_output = self.model.forward(examples)
|
||||
|
||||
ref = next(iter(native_output.values()))
|
||||
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
|
||||
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
|
||||
logs = {k: v.detach().item() for k, v in native_output.items()}
|
||||
logs["loss"] = total_loss.detach().item()
|
||||
return total_loss, logs
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.model.parameters()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot inference: convert → native predict → return as Tensor."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
batch_images = [ex["image"] for ex in examples]
|
||||
instructions = [ex["lang"] for ex in examples]
|
||||
|
||||
state_np = None
|
||||
if "state" in examples[0] and examples[0]["state"] is not None:
|
||||
state_np = np.stack([ex["state"] for ex in examples])
|
||||
|
||||
actions_np = self.model.predict_action(batch_images, instructions, state_np)
|
||||
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot select_action with action queue caching."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
**kwargs,
|
||||
):
|
||||
return super().from_pretrained(pretrained_name_or_path, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
reinit_prefixes = model.config.reinit_modules
|
||||
if not reinit_prefixes:
|
||||
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(model_file, device=map_location)
|
||||
current = model.state_dict()
|
||||
|
||||
reinitialized: list[str] = []
|
||||
filtered: dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key in current and value.shape != current[key].shape:
|
||||
if not any(key.startswith(p) for p in reinit_prefixes):
|
||||
raise ValueError(
|
||||
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
|
||||
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
|
||||
)
|
||||
reinitialized.append(
|
||||
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
|
||||
)
|
||||
else:
|
||||
filtered[key] = value
|
||||
|
||||
if reinitialized:
|
||||
logging.warning(
|
||||
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
|
||||
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
|
||||
)
|
||||
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
|
||||
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
|
||||
log_model_loading_keys(missing_keys, unexpected_keys)
|
||||
return model
|
||||
139
src/lerobot/policies/vla_jepa/processor_vla_jepa.py
Normal file
139
src/lerobot/policies/vla_jepa/processor_vla_jepa.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
|
||||
class ClipActionsProcessorStep(ProcessorStep):
|
||||
"""Clips action tensor to [-1, 1] before unnormalization."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
transition = dict(transition)
|
||||
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
|
||||
class PreSnapGripperProcessorStep(ProcessorStep):
|
||||
"""Snaps gripper dim (index 6) to {0, 1} BEFORE unnormalization.
|
||||
|
||||
Mirrors the original starVLA LIBERO eval:
|
||||
normalized[:, 6] = np.where(normalized[:, 6] < 0.5, 0, 1)
|
||||
This ensures the unnormalizer receives an exact binary value, which is
|
||||
required when the model was trained with gripper in identity (mask=False)
|
||||
space where 0=open and 1=close.
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] >= 7:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., 6] = (a[..., 6] >= 0.5).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
|
||||
class BinarizeGripperProcessorStep(ProcessorStep):
|
||||
"""Binarizes gripper dim (index 6) after unnormalization.
|
||||
|
||||
Maps continuous value to {-1, 1}: > 0.5 → -1, <= 0.5 → 1 (matches starVLA convention).
|
||||
Only applied when action has >= 7 dimensions.
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] >= 7:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., 6] = 1.0 - 2.0 * (a[..., 6] > 0.5).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
def make_vla_jepa_pre_post_processors(
|
||||
config: VLAJEPAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
features = {**config.input_features, **config.output_features}
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps: list[ProcessorStep] = []
|
||||
if config.clip_normalized_actions:
|
||||
output_steps.append(ClipActionsProcessorStep())
|
||||
if config.pre_snap_gripper_action:
|
||||
output_steps.append(PreSnapGripperProcessorStep())
|
||||
output_steps.append(
|
||||
UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
)
|
||||
)
|
||||
if config.binarize_gripper_action:
|
||||
output_steps.append(BinarizeGripperProcessorStep())
|
||||
output_steps.append(DeviceProcessorStep(device="cpu"))
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_new_line_processor")
|
||||
class VLAJEPANewLineProcessor(ComplementaryDataProcessorStep):
|
||||
def complementary_data(self, complementary_data):
|
||||
return complementary_data
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
103
src/lerobot/policies/vla_jepa/qwen_interface.py
Normal file
103
src/lerobot/policies/vla_jepa/qwen_interface.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
||||
else:
|
||||
AutoProcessor = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
class Qwen3VLInterface(torch.nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
config.qwen_model_name,
|
||||
torch_dtype=self._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
|
||||
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
|
||||
self.model.config.hidden_size = self.model.config.text_config.hidden_size
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||
if dtype_name == "float32":
|
||||
return torch.float32
|
||||
if dtype_name == "float16":
|
||||
return torch.float16
|
||||
return torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
|
||||
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
|
||||
# is required for Qwen embedding/lm_head checkpoint shapes to match.
|
||||
max_action_tokens = self.config.chunk_size * 4
|
||||
tokenizer = self.processor.tokenizer
|
||||
action_tokens = []
|
||||
action_token_ids = []
|
||||
for idx in range(max_action_tokens):
|
||||
token = self.config.special_action_token.format(idx)
|
||||
action_tokens.append(token)
|
||||
if token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([token], special_tokens=True)
|
||||
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
|
||||
|
||||
embodied_action_token = self.config.embodied_action_token
|
||||
if embodied_action_token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
|
||||
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
|
||||
|
||||
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
|
||||
self.model.resize_token_embeddings(len(tokenizer))
|
||||
return action_tokens, action_token_ids, embodied_action_token_id
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: Sequence[Sequence[Image.Image]],
|
||||
instructions: Sequence[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
messages = []
|
||||
for sample_images, instruction in zip(images, instructions, strict=True):
|
||||
prompt = self.config.prompt_template.format(
|
||||
instruction=instruction,
|
||||
actions=action_prompt,
|
||||
e_actions=embodied_prompt,
|
||||
)
|
||||
content = [{"type": "image", "image": img} for img in sample_images]
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages.append([{"role": "user", "content": content}])
|
||||
|
||||
batch_inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
processor_kwargs={"padding": True, "return_tensors": "pt"},
|
||||
)
|
||||
return batch_inputs.to(self.model.device)
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = image.float()
|
||||
if image.max() <= 1.0:
|
||||
image = image * 255.0
|
||||
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
|
||||
if image.shape[-1] == 1:
|
||||
image = np.repeat(image, 3, axis=-1)
|
||||
return Image.fromarray(image)
|
||||
404
src/lerobot/policies/vla_jepa/world_model.py
Normal file
404
src/lerobot/policies/vla_jepa/world_model.py
Normal file
@@ -0,0 +1,404 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
|
||||
|
||||
def build_action_block_causal_attention_mask(
|
||||
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
|
||||
) -> torch.Tensor:
|
||||
tokens_per_frame = add_tokens + grid_height * grid_width
|
||||
num_tokens = num_frames * tokens_per_frame
|
||||
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
|
||||
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
|
||||
local_window_time = num_frames
|
||||
|
||||
for current_frame in range(num_frames):
|
||||
first_context_frame = max(0, current_frame - local_window_time + 1)
|
||||
for context_frame in range(first_context_frame, current_frame + 1):
|
||||
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
|
||||
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
|
||||
mask[row, col] = mask_block
|
||||
return mask
|
||||
|
||||
|
||||
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, dim = x.size()
|
||||
if dim % 2 != 0:
|
||||
raise ValueError("Embedding dimension must be even for rotary position encoding.")
|
||||
|
||||
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
|
||||
omega /= dim / 2.0
|
||||
omega = 1.0 / 10000**omega
|
||||
freqs = torch.einsum("..., f -> ... f", pos, omega)
|
||||
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
|
||||
y = x.unflatten(-1, (-1, 2))
|
||||
y1, y2 = y.unbind(dim=-1)
|
||||
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
|
||||
return x * emb_cos + y * emb_sin
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
def __init__(self, drop_prob: float = 0.0) -> None:
|
||||
super().__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.drop_prob == 0.0 or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_()
|
||||
return x.div(keep_prob) * random_tensor
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int | None = None,
|
||||
out_features: int | None = None,
|
||||
act_layer: type[nn.Module] = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class ACRoPEAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: float | None = None,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = qk_scale or self.head_dim**-0.5
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop_prob = proj_drop
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.use_sdpa = use_sdpa
|
||||
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.grid_size = grid_size
|
||||
self.is_causal = is_causal
|
||||
|
||||
@staticmethod
|
||||
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
return ids // int(height * width)
|
||||
|
||||
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
ids = ids - int(height * width) * frame_ids
|
||||
return ids // width
|
||||
|
||||
def separate_positions(
|
||||
self, ids: torch.Tensor, height: int, width: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
height_ids = self._get_height_pos(ids, height, width)
|
||||
width_ids = ids - int(height * width) * frame_ids - width * height_ids
|
||||
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor | None = None,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_tokens, channels = x.size()
|
||||
if num_frames is None or grid_height is None or grid_width is None:
|
||||
raise ValueError("num_frames, grid_height and grid_width are required.")
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
else:
|
||||
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
|
||||
h_mask *= self.grid_size / grid_height
|
||||
w_mask *= self.grid_size / grid_width
|
||||
|
||||
if action_tokens > 0:
|
||||
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
|
||||
action_q, action_k, action_v = [], [], []
|
||||
for idx in range(action_tokens):
|
||||
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
|
||||
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
qd = rotate_queries_or_keys(
|
||||
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
kd = rotate_queries_or_keys(
|
||||
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
qr = q[..., self.d_dim :]
|
||||
kr = k[..., self.d_dim :]
|
||||
action_q.append(
|
||||
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_k.append(
|
||||
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||
|
||||
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
|
||||
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
|
||||
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
|
||||
x = x[:, :, action_tokens:, :].flatten(1, 2)
|
||||
|
||||
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
offset = 0
|
||||
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
offset += self.d_dim
|
||||
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
offset += self.h_dim
|
||||
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
offset += self.w_dim
|
||||
|
||||
if offset < self.head_dim:
|
||||
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
|
||||
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
|
||||
else:
|
||||
q = torch.cat([qd, qh, qw], dim=-1)
|
||||
k = torch.cat([kd, kh, kw], dim=-1)
|
||||
|
||||
if action_tokens > 0:
|
||||
|
||||
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
|
||||
frame_tokens = frame_tokens.view(
|
||||
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
|
||||
)
|
||||
action_token_values = action_token_values.view(
|
||||
batch_size, self.num_heads, num_frames, action_tokens, -1
|
||||
)
|
||||
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
|
||||
|
||||
q = merge(q, action_q)
|
||||
k = merge(k, action_k)
|
||||
v = merge(v, action_v)
|
||||
|
||||
if attn_mask is not None or self.use_sdpa:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
|
||||
)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
|
||||
x = self.proj(x)
|
||||
return self.proj_drop(x)
|
||||
|
||||
|
||||
class ACBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
qk_scale: float | None = None,
|
||||
drop: float = 0.0,
|
||||
attn_drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
use_rope: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
if not use_rope:
|
||||
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
|
||||
self.attn = ACRoPEAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
use_sdpa=use_sdpa,
|
||||
is_causal=is_causal,
|
||||
grid_size=grid_size,
|
||||
proj_drop=drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = MLP(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=nn.GELU,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
y = self.norm1(x)
|
||||
y = self.attn(
|
||||
y,
|
||||
mask=None,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=grid_height,
|
||||
grid_width=grid_width,
|
||||
action_tokens=action_tokens,
|
||||
)
|
||||
x = x + self.drop_path(y)
|
||||
y = self.norm2(x)
|
||||
return x + self.drop_path(self.mlp(y))
|
||||
|
||||
|
||||
class ActionConditionedVideoPredictor(nn.Module):
|
||||
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_frames: int,
|
||||
img_size: tuple[int, int],
|
||||
patch_size: int,
|
||||
tubelet_size: int,
|
||||
embed_dim: int,
|
||||
action_embed_dim: int,
|
||||
predictor_embed_dim: int,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
num_action_tokens_per_step: int,
|
||||
use_extrinsics: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_frame_causal = True
|
||||
self.use_extrinsics = use_extrinsics
|
||||
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
|
||||
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
|
||||
|
||||
self.img_height, self.img_width = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_frames = num_frames
|
||||
self.tubelet_size = tubelet_size
|
||||
self.grid_height = self.img_height // self.patch_size
|
||||
self.grid_width = self.img_width // self.patch_size
|
||||
|
||||
self.predictor_blocks = nn.ModuleList(
|
||||
[
|
||||
ACBlock(
|
||||
dim=predictor_embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=True,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
|
||||
grid_size=self.grid_height,
|
||||
use_rope=True,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
|
||||
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
|
||||
self.num_action_tokens_per_step = num_action_tokens_per_step
|
||||
|
||||
@property
|
||||
def norm(self) -> nn.LayerNorm:
|
||||
return self.predictor_norm
|
||||
|
||||
@property
|
||||
def proj(self) -> nn.Linear:
|
||||
return self.predictor_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
frame_tokens: torch.Tensor,
|
||||
action_tokens: torch.Tensor,
|
||||
extrinsics: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
|
||||
x = self.predictor_embed(frame_tokens)
|
||||
batch_size, num_context_tokens, hidden_dim = x.size()
|
||||
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
|
||||
|
||||
actions = self.action_encoder(action_tokens)
|
||||
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
|
||||
cond_tokens = actions.shape[2]
|
||||
|
||||
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
|
||||
if self.use_extrinsics:
|
||||
if extrinsics is None:
|
||||
raise ValueError("extrinsics are required when use_extrinsics=True.")
|
||||
cond_tokens += 1
|
||||
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
|
||||
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
|
||||
else:
|
||||
x = torch.cat([actions, x], dim=2).flatten(1, 2)
|
||||
|
||||
attn_mask = build_action_block_causal_attention_mask(
|
||||
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
|
||||
)
|
||||
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
|
||||
|
||||
for block in self.predictor_blocks:
|
||||
x = block(
|
||||
x,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=self.grid_height,
|
||||
grid_width=self.grid_width,
|
||||
action_tokens=cond_tokens,
|
||||
)
|
||||
|
||||
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
|
||||
x = x[:, :, cond_tokens:, :].flatten(1, 2)
|
||||
x = self.predictor_norm(x)
|
||||
return self.predictor_proj(x)
|
||||
@@ -95,6 +95,13 @@ from .relative_action_processor import (
|
||||
from .rename_processor import RenameObservationsProcessorStep, rename_stats
|
||||
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
|
||||
|
||||
# RenderMessagesStep is intentionally NOT re-exported here: it pulls in
|
||||
# `lerobot.datasets.language`, which requires the `[dataset]` extra
|
||||
# (`datasets`, `pyarrow`). Importing it from the processor package would
|
||||
# break every base-install consumer of `lerobot.processor`. Users that
|
||||
# need it import directly:
|
||||
# from lerobot.processor.render_messages_processor import RenderMessagesStep
|
||||
|
||||
__all__ = [
|
||||
"ActionProcessorStep",
|
||||
"AddTeleopActionAsComplimentaryDataStep",
|
||||
|
||||
@@ -174,6 +174,24 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
task_index_value = complementary_data["task_index"]
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
|
||||
complementary_data.pop("language_persistent", None)
|
||||
complementary_data.pop("language_events", None)
|
||||
|
||||
if "messages" in complementary_data:
|
||||
messages = complementary_data["messages"]
|
||||
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
|
||||
complementary_data["messages"] = [messages]
|
||||
|
||||
if "message_streams" in complementary_data:
|
||||
streams = complementary_data["message_streams"]
|
||||
if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
|
||||
complementary_data["message_streams"] = [streams]
|
||||
|
||||
if "target_message_indices" in complementary_data:
|
||||
indices = complementary_data["target_message_indices"]
|
||||
if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
|
||||
complementary_data["target_message_indices"] = [indices]
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -153,26 +153,30 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
|
||||
return x
|
||||
|
||||
|
||||
_COMPLEMENTARY_KEYS = (
|
||||
"task",
|
||||
"index",
|
||||
"task_index",
|
||||
"episode_index",
|
||||
"timestamp",
|
||||
"language_persistent",
|
||||
"language_events",
|
||||
"messages",
|
||||
"message_streams",
|
||||
"target_message_indices",
|
||||
)
|
||||
|
||||
|
||||
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Extract complementary data from a batch dictionary.
|
||||
"""Extract complementary data from a batch dictionary.
|
||||
|
||||
This includes padding flags, task description, and indices.
|
||||
|
||||
Args:
|
||||
batch: The batch dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary with the extracted complementary data.
|
||||
Includes padding flags (any key containing ``_is_pad``) plus the fixed
|
||||
set of metadata / language keys defined in ``_COMPLEMENTARY_KEYS`` —
|
||||
each only when present in ``batch``.
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
|
||||
extras = {k: batch[k] for k in _COMPLEMENTARY_KEYS if k in batch}
|
||||
return {**pad_keys, **extras}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
84
src/lerobot/processor/render_messages_processor.py
Normal file
84
src/lerobot/processor/render_messages_processor.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
|
||||
from lerobot.datasets.language_render import render_sample
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.utils import unwrap_scalar
|
||||
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="render_messages_processor")
|
||||
class RenderMessagesStep(ProcessorStep):
|
||||
"""Processor step that turns raw language columns into rendered chat messages.
|
||||
|
||||
Reads ``language_persistent`` and ``language_events`` from the transition's
|
||||
complementary data, renders them through ``recipe`` at the sample timestamp,
|
||||
and replaces the raw columns with the resulting ``messages`` /
|
||||
``message_streams`` / ``target_message_indices`` keys.
|
||||
"""
|
||||
|
||||
recipe: TrainingRecipe
|
||||
dataset_ctx: Any | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
"""Render messages for a single transition; return ``None`` to drop it."""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
|
||||
events = complementary_data.get(LANGUAGE_EVENTS) or []
|
||||
|
||||
if not persistent and not events:
|
||||
return transition
|
||||
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
|
||||
|
||||
sample_idx = complementary_data.get("index", 0)
|
||||
rendered = render_sample(
|
||||
recipe=self.recipe,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=unwrap_scalar(timestamp),
|
||||
sample_idx=int(unwrap_scalar(sample_idx)),
|
||||
task=complementary_data.get("task"),
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
return None
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data.update(rendered)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Pass features through unchanged; rendering only touches complementary data."""
|
||||
return features
|
||||
@@ -21,11 +21,13 @@ from .factory import (
|
||||
)
|
||||
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfig
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"RewardClassifierConfig",
|
||||
"SARMConfig",
|
||||
"TOPRewardConfig",
|
||||
# Base class
|
||||
"PreTrainedRewardModel",
|
||||
# Factory functions
|
||||
|
||||
@@ -17,10 +17,11 @@ import logging
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
|
||||
from ..pretrained import PreTrainedRewardModel
|
||||
from .configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
@@ -25,7 +25,8 @@ from lerobot.processor import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
from .configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
def make_classifier_processor(
|
||||
|
||||
@@ -22,9 +22,11 @@ import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig
|
||||
from .pretrained import PreTrainedRewardModel
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
from .topreward.configuration_topreward import TOPRewardConfig
|
||||
|
||||
|
||||
def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
@@ -36,7 +38,7 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
|
||||
Args:
|
||||
name: The name of the reward model. Supported names are "reward_classifier",
|
||||
"sarm".
|
||||
"sarm", "topreward".
|
||||
|
||||
Returns:
|
||||
The reward model class corresponding to the given name.
|
||||
@@ -52,6 +54,10 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "topreward":
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
return TOPRewardModel
|
||||
else:
|
||||
try:
|
||||
return _get_reward_model_cls_from_name(name=name)
|
||||
@@ -68,7 +74,7 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
|
||||
Args:
|
||||
reward_type: The type of the reward model. Supported types include
|
||||
"reward_classifier", "sarm".
|
||||
"reward_classifier", "sarm", "topreward".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -81,6 +87,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif reward_type == "sarm":
|
||||
return SARMConfig(**kwargs)
|
||||
elif reward_type == "topreward":
|
||||
return TOPRewardConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
||||
@@ -161,6 +169,14 @@ def make_reward_pre_post_processors(
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, TOPRewardConfig):
|
||||
from lerobot.rewards.topreward.processor_topreward import make_topreward_pre_post_processors
|
||||
|
||||
return make_topreward_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_reward_model_config(
|
||||
|
||||
@@ -58,9 +58,10 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
|
||||
from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
from lerobot.rewards.sarm.sarm_utils import normalize_stage_tau
|
||||
|
||||
from .modeling_sarm import SARMRewardModel
|
||||
from .processor_sarm import make_sarm_pre_post_processors
|
||||
from .sarm_utils import normalize_stage_tau
|
||||
|
||||
|
||||
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
|
||||
|
||||
@@ -32,13 +32,14 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.rewards.sarm.sarm_utils import (
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
|
||||
from ..pretrained import PreTrainedRewardModel
|
||||
from .configuration_sarm import SARMConfig
|
||||
from .sarm_utils import (
|
||||
normalize_stage_tau,
|
||||
pad_state_to_max_dim,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
|
||||
|
||||
class StageTransformer(nn.Module):
|
||||
|
||||
@@ -58,15 +58,16 @@ from lerobot.processor import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.rewards.sarm.sarm_utils import (
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from .configuration_sarm import SARMConfig
|
||||
from .sarm_utils import (
|
||||
apply_rewind_augmentation,
|
||||
compute_absolute_indices,
|
||||
find_stage_and_tau,
|
||||
pad_state_to_max_dim,
|
||||
)
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
19
src/lerobot/rewards/topreward/__init__.py
Normal file
19
src/lerobot/rewards/topreward/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_topreward import TOPRewardConfig
|
||||
from .modeling_topreward import TOPRewardModel
|
||||
from .processor_topreward import make_topreward_pre_post_processors
|
||||
|
||||
__all__ = ["TOPRewardConfig", "TOPRewardModel", "make_topreward_pre_post_processors"]
|
||||
353
src/lerobot/rewards/topreward/compute_rabc_weights.py
Normal file
353
src/lerobot/rewards/topreward/compute_rabc_weights.py
Normal file
@@ -0,0 +1,353 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compute per-frame TOPReward progress curves for a LeRobot dataset.
|
||||
|
||||
For each episode, scores trajectory prefixes of increasing length using
|
||||
the TOPReward reward model, min-max normalises the raw log-prob rewards per episode,
|
||||
and writes a parquet file with one row per frame.
|
||||
|
||||
The parquet uses the same schema as SARM's :mod:`lerobot.rewards.sarm.compute_rabc_weights`.
|
||||
|
||||
Usage:
|
||||
# Sparse-dense mode (15 anchors per episode, matches upstream)
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--num-samples 15
|
||||
|
||||
# Use a different VLM backbone
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--vlm-name Qwen/Qwen3-VL-4B-Instruct
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.rewards.topreward.configuration_topreward import TOPRewardConfig
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
from lerobot.rewards.topreward.processor_topreward import TOPRewardEncoderProcessorStep
|
||||
from lerobot.types import TransitionKey
|
||||
|
||||
DEFAULT_OUTPUT_FILENAME = "topreward_progress.parquet"
|
||||
|
||||
|
||||
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
|
||||
"""Read ``reward_model_path`` from parquet metadata if available."""
|
||||
if not parquet_path.exists():
|
||||
return None
|
||||
try:
|
||||
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
|
||||
if metadata and b"reward_model_path" in metadata:
|
||||
return metadata[b"reward_model_path"].decode()
|
||||
except Exception: # nosec B110
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_task(sample: dict[str, Any], default: str) -> str:
|
||||
"""Best-effort task extraction from a dataset sample."""
|
||||
task = sample.get("task")
|
||||
if isinstance(task, str) and task:
|
||||
return task
|
||||
return default
|
||||
|
||||
|
||||
def normalize_rewards(rewards: list[float] | np.ndarray) -> np.ndarray:
|
||||
"""Min-max normalise raw log-prob rewards into ``[0, 1]``."""
|
||||
rewards_arr = np.asarray(rewards, dtype=np.float64)
|
||||
if rewards_arr.size == 0:
|
||||
return rewards_arr.astype(np.float32)
|
||||
if rewards_arr.size == 1:
|
||||
return np.array([1.0], dtype=np.float32)
|
||||
r_min, r_max = rewards_arr.min(), rewards_arr.max()
|
||||
if r_max == r_min:
|
||||
return np.ones_like(rewards_arr, dtype=np.float32)
|
||||
return ((rewards_arr - r_min) / (r_max - r_min)).astype(np.float32)
|
||||
|
||||
|
||||
def compute_instruction_rewards_for_prefixes(
|
||||
model: TOPRewardModel,
|
||||
encoder: TOPRewardEncoderProcessorStep,
|
||||
dataset: LeRobotDataset,
|
||||
ep_start: int,
|
||||
num_frames: int,
|
||||
task: str,
|
||||
image_key: str,
|
||||
num_samples: int | None,
|
||||
device: str,
|
||||
) -> np.ndarray:
|
||||
"""Score an episode via prefix sweep and return a per-frame normalised curve."""
|
||||
if num_samples is None or num_samples >= num_frames:
|
||||
prefix_lengths = np.arange(1, num_frames + 1, dtype=np.int64)
|
||||
else:
|
||||
prefix_lengths = np.unique(np.linspace(1, num_frames, num_samples).round().astype(np.int64))
|
||||
|
||||
episode_frames = torch.stack([dataset[ep_start + i][image_key] for i in range(num_frames)])
|
||||
rewards: list[float] = []
|
||||
for length in prefix_lengths:
|
||||
frames = episode_frames[: int(length)].unsqueeze(0) # (1, T, C, H, W)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {image_key: frames},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": task},
|
||||
}
|
||||
encoded = encoder(transition)
|
||||
obs = encoded[TransitionKey.OBSERVATION]
|
||||
batch = {
|
||||
key: value.to(device) if isinstance(value, torch.Tensor) else value for key, value in obs.items()
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
reward = model.compute_reward(batch)
|
||||
rewards.append(float(reward.item()))
|
||||
|
||||
normalized_rewards = normalize_rewards(rewards)
|
||||
|
||||
if prefix_lengths.shape[0] == num_frames:
|
||||
return normalized_rewards
|
||||
|
||||
return np.interp(
|
||||
np.arange(1, num_frames + 1, dtype=np.float64),
|
||||
prefix_lengths.astype(np.float64),
|
||||
normalized_rewards.astype(np.float64),
|
||||
).astype(np.float32)
|
||||
|
||||
|
||||
def compute_topreward_progress(
|
||||
dataset_repo_id: str,
|
||||
reward_model_path: str | None = None,
|
||||
vlm_name: str | None = None,
|
||||
output_path: str | None = None,
|
||||
device: str = "cuda",
|
||||
num_samples: int | None = None,
|
||||
fps: float | None = None,
|
||||
episodes: list[int] | None = None,
|
||||
) -> Path:
|
||||
"""Run TOPReward over a dataset and write per-frame progress."""
|
||||
if reward_model_path is not None:
|
||||
logging.info(f"Loading TOPReward config from: {reward_model_path}")
|
||||
model = TOPRewardModel.from_pretrained(reward_model_path)
|
||||
config = model.config
|
||||
config.device = device
|
||||
if vlm_name is not None and vlm_name != config.vlm_name:
|
||||
logging.info(f"Overriding vlm_name from config: {config.vlm_name} -> {vlm_name}")
|
||||
config.vlm_name = vlm_name
|
||||
model = TOPRewardModel(config)
|
||||
else:
|
||||
config_kwargs: dict[str, Any] = {"device": device}
|
||||
if vlm_name is not None:
|
||||
config_kwargs["vlm_name"] = vlm_name
|
||||
if fps is not None:
|
||||
config_kwargs["fps"] = fps
|
||||
config = TOPRewardConfig(**config_kwargs)
|
||||
logging.info(f"Constructing TOPReward with VLM: {config.vlm_name}")
|
||||
model = TOPRewardModel(config)
|
||||
|
||||
model.to(device).eval()
|
||||
|
||||
encoder = TOPRewardEncoderProcessorStep(
|
||||
vlm_name=config.vlm_name,
|
||||
image_key=config.image_key,
|
||||
task_key=config.task_key,
|
||||
default_task=config.default_task,
|
||||
max_frames=None, # no tail-crop: we control prefix length explicitly
|
||||
fps=config.fps,
|
||||
prompt_prefix=config.prompt_prefix,
|
||||
prompt_suffix_template=config.prompt_suffix_template,
|
||||
add_chat_template=config.add_chat_template,
|
||||
max_length=config.max_input_length,
|
||||
)
|
||||
|
||||
image_key = config.image_key
|
||||
|
||||
logging.info(f"Loading dataset: {dataset_repo_id}")
|
||||
dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
|
||||
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||
|
||||
episode_indices = list(range(dataset.num_episodes)) if episodes is None else episodes
|
||||
logging.info(f"Processing {len(episode_indices)} episode(s)")
|
||||
|
||||
all_index: list[int] = []
|
||||
all_episode: list[int] = []
|
||||
all_frame: list[int] = []
|
||||
all_progress: list[float] = []
|
||||
|
||||
for episode_idx in tqdm(episode_indices, desc="Episodes"):
|
||||
ep = dataset.meta.episodes[episode_idx]
|
||||
ep_start = int(ep["dataset_from_index"])
|
||||
ep_end = int(ep["dataset_to_index"])
|
||||
num_frames = ep_end - ep_start
|
||||
if num_frames <= 0:
|
||||
continue
|
||||
|
||||
first_sample = dataset[ep_start]
|
||||
task = _resolve_task(first_sample, default=config.default_task or "perform the task")
|
||||
|
||||
per_frame = compute_instruction_rewards_for_prefixes(
|
||||
model=model,
|
||||
encoder=encoder,
|
||||
dataset=dataset,
|
||||
ep_start=ep_start,
|
||||
num_frames=num_frames,
|
||||
task=task,
|
||||
image_key=image_key,
|
||||
num_samples=num_samples,
|
||||
device=device,
|
||||
)
|
||||
|
||||
for local in range(num_frames):
|
||||
all_index.append(ep_start + local)
|
||||
all_episode.append(episode_idx)
|
||||
all_frame.append(local)
|
||||
all_progress.append(float(per_frame[local]))
|
||||
|
||||
if device.startswith("cuda"):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"index": np.asarray(all_index, dtype=np.int64),
|
||||
"episode_index": np.asarray(all_episode, dtype=np.int64),
|
||||
"frame_index": np.asarray(all_frame, dtype=np.int64),
|
||||
"progress_sparse": np.asarray(all_progress, dtype=np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
schema_metadata: dict[bytes, bytes] = {b"vlm_name": config.vlm_name.encode()}
|
||||
if reward_model_path is not None:
|
||||
schema_metadata[b"reward_model_path"] = reward_model_path.encode()
|
||||
table = table.replace_schema_metadata(schema_metadata)
|
||||
|
||||
out = Path(dataset.root) / DEFAULT_OUTPUT_FILENAME if output_path is None else Path(output_path)
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(table, out)
|
||||
logging.info(f"Saved {len(table)} frame values to {out}")
|
||||
|
||||
progress_arr = np.asarray(all_progress, dtype=np.float32)
|
||||
if progress_arr.size:
|
||||
logging.info(
|
||||
f"Progress: mean={float(progress_arr.mean()):.4f}, "
|
||||
f"std={float(progress_arr.std()):.4f}, "
|
||||
f"min={float(progress_arr.min()):.4f}, "
|
||||
f"max={float(progress_arr.max()):.4f}"
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute per-frame TOPReward progress curves for RA-BC weighting.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Sparse-dense mode (matches upstream TOPReward num_samples=15)
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--num-samples 15
|
||||
|
||||
# Use a smaller VLM
|
||||
python -m lerobot.rewards.topreward.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--vlm-name Qwen/Qwen3-VL-4B-Instruct
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-repo-id", type=str, required=True, help="HuggingFace dataset repo id or local path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-model-path", type=str, default=None, help="Optional TOPReward LeRobot config."
|
||||
)
|
||||
parser.add_argument("--vlm-name", type=str, default=None, help="Override the VLM backbone (HF Hub id).")
|
||||
parser.add_argument("--output-path", type=str, default=None, help="Output parquet path.")
|
||||
parser.add_argument("--device", type=str, default="cuda", help="Device to use (default: cuda).")
|
||||
parser.add_argument(
|
||||
"--num-samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Anchor prefix samples per episode. None = dense. 15 matches upstream.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Process only these episode indices (e.g. --episodes 0 or --episodes 0 5 10).",
|
||||
)
|
||||
parser.add_argument("--fps", type=float, default=None, help="Override TOPRewardConfig.fps.")
|
||||
parser.add_argument(
|
||||
"--push-to-hub", action="store_true", help="Upload to the dataset repo on HuggingFace Hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
|
||||
output_path = compute_topreward_progress(
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
reward_model_path=args.reward_model_path,
|
||||
vlm_name=args.vlm_name,
|
||||
output_path=args.output_path,
|
||||
device=args.device,
|
||||
num_samples=args.num_samples,
|
||||
fps=args.fps,
|
||||
episodes=args.episodes,
|
||||
)
|
||||
|
||||
print(f"\nTOPReward progress saved to: {output_path}")
|
||||
|
||||
if args.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
hub_path = DEFAULT_OUTPUT_FILENAME
|
||||
|
||||
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_path),
|
||||
path_in_repo=hub_path,
|
||||
repo_id=args.dataset_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
print(
|
||||
"Successfully uploaded to: "
|
||||
f"https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
|
||||
)
|
||||
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
else:
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: {output_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
146
src/lerobot/rewards/topreward/configuration_topreward.py
Normal file
146
src/lerobot/rewards/topreward/configuration_topreward.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
# Default prompt scaffolding from the upstream TOPReward paper / reference
|
||||
# implementation (``QwenClient.compute_instruction_reward``). The prompt
|
||||
# scores the terminal ``True`` token in ``f"{instruction} ... True"``
|
||||
# given the video.
|
||||
DEFAULT_PROMPT_PREFIX = (
|
||||
"The above video shows a robot manipulation trajectory that completes the following task: "
|
||||
)
|
||||
DEFAULT_PROMPT_SUFFIX_TEMPLATE = (
|
||||
"{instruction} Decide whether the above statement is True or not. The answer is: True"
|
||||
)
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("topreward")
|
||||
@dataclass
|
||||
class TOPRewardConfig(RewardModelConfig):
|
||||
"""Configuration for the TOPReward zero-shot reward model.
|
||||
|
||||
TOPReward is **zero-shot**: it has no learnable parameters of its own.
|
||||
The "model" is a generic vision-language model (default
|
||||
``Qwen/Qwen3-VL-8B-Instruct``) used with a fixed prompt to extract
|
||||
token log-probabilities as a reward signal. There is therefore no
|
||||
fine-tuned checkpoint to host: ``pretrained_path`` is unused at
|
||||
runtime — the model identity is :attr:`vlm_name` (an HF Hub id).
|
||||
|
||||
Args:
|
||||
vlm_name: Hugging Face Hub id of the underlying VLM. Must be a
|
||||
Qwen3-VL family model (the only client implemented in this
|
||||
LeRobot port).
|
||||
torch_dtype: Torch dtype name passed to the VLM loader
|
||||
(``"auto"``, ``"bfloat16"``, ``"float16"``, ...).
|
||||
attn_implementation: ``transformers`` attention implementation
|
||||
(e.g. ``"flash_attention_2"``, ``"sdpa"``). Defaults to
|
||||
``None`` so the upstream picks the best available.
|
||||
image_key: Observation key that holds the trajectory frames.
|
||||
task_key: Complementary-data key that holds the task instruction.
|
||||
default_task: Fallback instruction when ``task_key`` is absent.
|
||||
max_frames: Cap on the number of frames fed to the VLM per
|
||||
sample. ``None`` = use all frames.
|
||||
fps: Frames-per-second metadata for the Qwen video processor.
|
||||
prompt_prefix: Text shown to the VLM right after the video and
|
||||
before the suffix template.
|
||||
prompt_suffix_template: Suffix appended after ``prompt_prefix``.
|
||||
Must contain ``{instruction}``; the VLM scores the
|
||||
log-likelihood of the tokens that follow the prefix.
|
||||
add_chat_template: If ``True``, wrap the full prompt with the
|
||||
tokenizer's chat template before tokenisation (matches
|
||||
upstream ``add_chat_template=True``).
|
||||
success_threshold: Optional log-prob threshold. If finite,
|
||||
:meth:`TOPRewardModel.compute_reward` returns
|
||||
``(reward > success_threshold).float()`` instead of the raw
|
||||
log-prob.
|
||||
max_input_length: Hard limit on the total tokenized input length;
|
||||
samples that exceed it raise a ``ValueError``.
|
||||
"""
|
||||
|
||||
# Path to a local LeRobot dir or HF repo that holds a ``config.json``
|
||||
# snapshot of this TOPRewardConfig. The VLM weights themselves are
|
||||
# always identified by ``vlm_name``.
|
||||
pretrained_path: str | None = None
|
||||
|
||||
vlm_name: str = "Qwen/Qwen3-VL-8B-Instruct"
|
||||
torch_dtype: str = "auto"
|
||||
attn_implementation: str | None = None
|
||||
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
max_frames: int | None = 16
|
||||
fps: float = 2.0
|
||||
|
||||
prompt_prefix: str = DEFAULT_PROMPT_PREFIX
|
||||
prompt_suffix_template: str = DEFAULT_PROMPT_SUFFIX_TEMPLATE
|
||||
add_chat_template: bool = False
|
||||
|
||||
success_threshold: float = float("-inf")
|
||||
max_input_length: int = 32768
|
||||
|
||||
license: str | None = "mit" # matches upstream TOPReward
|
||||
tags: list[str] | None = field(
|
||||
default_factory=lambda: ["reward-model", "vision-language", "qwen3-vl", "zero-shot"]
|
||||
)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.max_frames is not None and self.max_frames < 1:
|
||||
raise ValueError(f"max_frames must be >= 1, got {self.max_frames}")
|
||||
if self.fps <= 0:
|
||||
raise ValueError(f"fps must be > 0, got {self.fps}")
|
||||
if "{instruction}" not in self.prompt_suffix_template:
|
||||
raise ValueError(
|
||||
"prompt_suffix_template must contain `{instruction}` so the model "
|
||||
"scores the log-likelihood of the task suffix."
|
||||
)
|
||||
if self.max_input_length <= 0:
|
||||
raise ValueError(f"max_input_length must be > 0, got {self.max_input_length}")
|
||||
|
||||
if self.image_key not in self.input_features:
|
||||
self.input_features[self.image_key] = PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL)
|
||||
self.output_features.setdefault("reward", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if self.image_key not in self.input_features:
|
||||
raise ValueError(f"TOPReward requires image input feature {self.image_key!r}")
|
||||
238
src/lerobot/rewards/topreward/modeling_topreward.py
Normal file
238
src/lerobot/rewards/topreward/modeling_topreward.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# Copyright 2026 Shirui Chen, Cole Harrison, Ying-Chun Lee, Angela Jin Yang,
|
||||
# Zhongzheng Ren, Lillian J. Ratliff, Jiafei Duan, Dieter Fox, Ranjay Krishna
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics.
|
||||
|
||||
Paper: https://arxiv.org/abs/2602.19313
|
||||
Project: https://topreward.github.io/webpage/
|
||||
Original code: https://github.com/TOPReward/TOPReward
|
||||
Backbone: https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct (default)
|
||||
|
||||
TOPReward is a **zero-shot** reward model: it has no fine-tuned weights of
|
||||
its own. Given a video trajectory and a task instruction, it asks an
|
||||
off-the-shelf VLM how likely the instruction is, conditioned on the video,
|
||||
and returns that log-likelihood as the reward signal.
|
||||
|
||||
Inference recipe:
|
||||
|
||||
1. The processor builds a chat-style prompt, tokenises it, and emits
|
||||
``input_ids``, ``attention_mask``, vision tensors, and ``labels``.
|
||||
The processor label-masks everything except the terminal answer token with
|
||||
``-100``.
|
||||
2. Forward the full token sequence through the VLM.
|
||||
3. Read the terminal answer token log-probability from the logits as the
|
||||
scalar reward.
|
||||
|
||||
With the default ``prompt_suffix_template``, the only unmasked token is the
|
||||
literal ``"True"`` at the end — the reward is
|
||||
``log P("True" | video + prompt + instruction)``.
|
||||
|
||||
This LeRobot port is **inference-only and not trainable** — :meth:`forward`
|
||||
is intentionally inherited from :class:`PreTrainedRewardModel` and raises
|
||||
``NotImplementedError``, making :attr:`PreTrainedRewardModel.is_trainable`
|
||||
return ``False``.
|
||||
|
||||
Because the VLM weights live on the Hugging Face Hub under their canonical
|
||||
id (``Qwen/Qwen3-VL-8B-Instruct`` etc.) and TOPReward never modifies them,
|
||||
:meth:`_save_pretrained` and :meth:`from_pretrained` are overridden so a
|
||||
TOPReward LeRobot "checkpoint" is a single ``config.json`` (the VLM is
|
||||
re-fetched from the Hub at load time).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from torch import Tensor
|
||||
from torch.nn.functional import cross_entropy
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.topreward.configuration_topreward import TOPRewardConfig
|
||||
from lerobot.rewards.topreward.processor_topreward import TOPREWARD_FEATURE_PREFIX, TOPREWARD_INPUT_KEYS
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import Qwen3VLForConditionalGeneration
|
||||
else:
|
||||
Qwen3VLForConditionalGeneration = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", bound="TOPRewardModel")
|
||||
|
||||
|
||||
def _torch_dtype(name: str) -> torch.dtype | str:
|
||||
"""Resolve a torch dtype name; ``"auto"`` is passed through verbatim."""
|
||||
if name == "auto":
|
||||
return "auto"
|
||||
dtype = getattr(torch, name, None)
|
||||
if isinstance(dtype, torch.dtype):
|
||||
return dtype
|
||||
raise ValueError(f"Unknown torch dtype: {name!r}")
|
||||
|
||||
|
||||
class TOPRewardModel(PreTrainedRewardModel):
|
||||
"""TOPReward zero-shot reward model."""
|
||||
|
||||
name = "topreward"
|
||||
config_class = TOPRewardConfig
|
||||
|
||||
def __init__(self, config: TOPRewardConfig) -> None:
|
||||
require_package("transformers", extra="topreward")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
torch_dtype = _torch_dtype(config.torch_dtype)
|
||||
model_kwargs: dict[str, Any] = {"dtype": torch_dtype, "trust_remote_code": True}
|
||||
if config.attn_implementation is not None:
|
||||
model_kwargs["attn_implementation"] = config.attn_implementation
|
||||
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(config.vlm_name, **model_kwargs)
|
||||
|
||||
def compute_reward(self, batch: dict[str, Any]) -> Tensor:
|
||||
"""Return one log-prob reward per sample in the batch."""
|
||||
inputs: dict[str, Any] = {}
|
||||
for key in TOPREWARD_INPUT_KEYS:
|
||||
batch_key = f"{TOPREWARD_FEATURE_PREFIX}{key}"
|
||||
if batch_key not in batch:
|
||||
raise KeyError(
|
||||
f"TOPReward batch missing `{batch_key}`. Make sure the "
|
||||
"TOPRewardEncoderProcessorStep ran before `compute_reward`."
|
||||
)
|
||||
inputs[key] = batch[batch_key]
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
inputs = {key: value.to(device) if hasattr(value, "to") else value for key, value in inputs.items()}
|
||||
labels = inputs.pop("labels")
|
||||
inputs["logits_to_keep"] = 2
|
||||
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
logits = outputs.logits
|
||||
rewards = -cross_entropy(logits[:, -2, :].float(), labels[:, -1], reduction="none")
|
||||
if np.isfinite(self.config.success_threshold):
|
||||
rewards = (rewards > self.config.success_threshold).float()
|
||||
return rewards.to(self.config.device or "cpu")
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
"""Save ``config.json`` only."""
|
||||
self.config._save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: RewardModelConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = False, # noqa: ARG003 — accepted for API parity; unused (no safetensors to load)
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
"""Load a TOPReward configuration and instantiate the wrapped VLM."""
|
||||
if config is None:
|
||||
config = RewardModelConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
if not isinstance(config, TOPRewardConfig):
|
||||
raise TypeError(
|
||||
f"Expected a TOPRewardConfig, got {type(config).__name__}. Make sure "
|
||||
f"`pretrained_name_or_path={pretrained_name_or_path!r}` points at a "
|
||||
"TOPReward checkpoint."
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
if not os.path.isdir(model_id):
|
||||
try:
|
||||
hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
instance = cls(config, **kwargs)
|
||||
instance.to(config.device)
|
||||
instance.eval()
|
||||
return instance
|
||||
|
||||
def push_model_to_hub(self, cfg: TrainPipelineConfig):
|
||||
"""Push the TOPReward ``config.json`` + model card to the Hub."""
|
||||
api = HfApi()
|
||||
repo_id = api.create_repo(
|
||||
repo_id=self.config.repo_id, private=self.config.private, exist_ok=True
|
||||
).repo_id
|
||||
|
||||
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
|
||||
saved_path = Path(tmp) / repo_id
|
||||
saved_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.config._save_pretrained(saved_path)
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
|
||||
)
|
||||
card.save(str(saved_path / "README.md"))
|
||||
|
||||
cfg.save_pretrained(saved_path)
|
||||
|
||||
commit_info = api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
folder_path=saved_path,
|
||||
commit_message="Upload TOPReward config and readme",
|
||||
allow_patterns=["*.json", "*.yaml", "*.md"],
|
||||
ignore_patterns=["*.tmp", "*.log", "*.safetensors"],
|
||||
)
|
||||
|
||||
logger.info(f"Model pushed to {commit_info.repo_url.url}")
|
||||
305
src/lerobot/rewards/topreward/processor_topreward.py
Normal file
305
src/lerobot/rewards/topreward/processor_topreward.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""TOPReward pre/post processing pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
policy_action_to_transition,
|
||||
)
|
||||
from lerobot.rewards.topreward.configuration_topreward import (
|
||||
DEFAULT_PROMPT_PREFIX,
|
||||
DEFAULT_PROMPT_SUFFIX_TEMPLATE,
|
||||
TOPRewardConfig,
|
||||
)
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
OBS_PREFIX,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor
|
||||
else:
|
||||
AutoProcessor = None
|
||||
|
||||
TOPREWARD_FEATURE_PREFIX = f"{OBS_PREFIX}topreward."
|
||||
|
||||
_TRUE_ANSWER = "True"
|
||||
|
||||
TOPREWARD_VLM_INPUT_KEYS = (
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"pixel_values_videos",
|
||||
"video_grid_thw",
|
||||
"mm_token_type_ids",
|
||||
)
|
||||
TOPREWARD_INPUT_KEYS = TOPREWARD_VLM_INPUT_KEYS + ("labels",)
|
||||
|
||||
|
||||
def _prepare_video_batch(video: Tensor, *, max_frames: int | None) -> Tensor:
|
||||
"""Return videos as ``(B, T, C, H, W)`` uint8 tensors for Qwen3-VL."""
|
||||
if video.ndim == 4:
|
||||
video = video.unsqueeze(1)
|
||||
elif video.ndim != 5:
|
||||
raise ValueError(
|
||||
f"Expected TOPReward frames with shape (B,C,H,W) or (B,T,C,H,W); got {tuple(video.shape)}"
|
||||
)
|
||||
|
||||
if max_frames is not None:
|
||||
video = video[:, -max_frames:]
|
||||
if video.shape[-1] in (1, 3):
|
||||
video = video.permute(0, 1, 4, 2, 3)
|
||||
elif video.shape[2] not in (1, 3):
|
||||
raise ValueError(f"Expected channel dim of size 1 or 3, got shape {tuple(video.shape)}")
|
||||
|
||||
if video.is_floating_point():
|
||||
video = video * 255.0
|
||||
|
||||
return video.clamp(0, 255).to(torch.uint8).contiguous()
|
||||
|
||||
|
||||
def _expand_tasks(task: Any, *, batch_size: int, default: str | None) -> list[str]:
|
||||
if task is None:
|
||||
task = default
|
||||
if task is None:
|
||||
raise KeyError("TOPReward expected a task description in complementary data")
|
||||
if isinstance(task, str):
|
||||
return [task] * batch_size
|
||||
if isinstance(task, tuple):
|
||||
task = list(task)
|
||||
if not (isinstance(task, list) and all(isinstance(item, str) for item in task)):
|
||||
raise TypeError(f"TOPReward task must be a string or list of strings, got {type(task)}")
|
||||
if len(task) == 1 and batch_size > 1:
|
||||
return task * batch_size
|
||||
if len(task) != batch_size:
|
||||
raise ValueError(f"Expected {batch_size} tasks, got {len(task)}")
|
||||
return task
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="topreward_encoder")
|
||||
class TOPRewardEncoderProcessorStep(ProcessorStep):
|
||||
"""Encode raw frames + task into Qwen-VL tensors for the TOPReward model.
|
||||
|
||||
Loads a :class:`~transformers.AutoProcessor` matching ``vlm_name`` and
|
||||
builds the full chat prompt including the instruction suffix. The
|
||||
resulting ``input_ids``, ``attention_mask``, vision tensors, and
|
||||
``labels`` are written under the ``observation.topreward.*`` namespace
|
||||
so the model can score without re-tokenising.
|
||||
|
||||
At call time the step reads:
|
||||
|
||||
- ``observation[image_key]``: ``(B, T, C, H, W)`` or ``(B, C, H, W)`` frames.
|
||||
- ``complementary_data[task_key]``: a string or list of strings.
|
||||
|
||||
and writes ``observation[f"{TOPREWARD_FEATURE_PREFIX}<name>"]`` for the
|
||||
Qwen-VL tensors plus ``labels``.
|
||||
"""
|
||||
|
||||
vlm_name: str = "Qwen/Qwen3-VL-8B-Instruct"
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
max_frames: int | None = 16
|
||||
fps: float = 2.0
|
||||
prompt_prefix: str = DEFAULT_PROMPT_PREFIX
|
||||
prompt_suffix_template: str = DEFAULT_PROMPT_SUFFIX_TEMPLATE
|
||||
add_chat_template: bool = False
|
||||
max_length: int = 32768
|
||||
|
||||
_processor: Any = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
require_package("transformers", extra="topreward")
|
||||
self._processor = AutoProcessor.from_pretrained(self.vlm_name, trust_remote_code=True)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
if self.image_key not in observation:
|
||||
raise KeyError(f"TOPReward expected image key {self.image_key!r} in observation")
|
||||
|
||||
frames = observation[self.image_key]
|
||||
videos = frames.detach().cpu() if isinstance(frames, Tensor) else torch.as_tensor(frames)
|
||||
videos = _prepare_video_batch(videos, max_frames=self.max_frames)
|
||||
|
||||
batch_size = videos.shape[0]
|
||||
tasks = _expand_tasks(
|
||||
complementary.get(self.task_key, self.default_task),
|
||||
batch_size=batch_size,
|
||||
default=self.default_task,
|
||||
)
|
||||
|
||||
encoded = self._encode_batch(videos, tasks, batch_size)
|
||||
|
||||
new_observation = dict(observation)
|
||||
for key, value in encoded.items():
|
||||
new_observation[f"{TOPREWARD_FEATURE_PREFIX}{key}"] = value
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
return new_transition
|
||||
|
||||
def _encode_batch(self, videos: Tensor, tasks: list[str], batch_size) -> dict[str, Any]:
|
||||
"""Tokenise a batch of (frames, task) pairs into Qwen-VL tensors.
|
||||
|
||||
The loop only builds per-sample chat strings. Tokenisation, padding,
|
||||
video preprocessing, and label construction are batched.
|
||||
"""
|
||||
|
||||
texts: list[str] = []
|
||||
video_metadata = [
|
||||
{
|
||||
"total_num_frames": int(videos.shape[1]),
|
||||
"fps": float(self.fps),
|
||||
"frames_indices": list(range(int(videos.shape[1]))),
|
||||
}
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
eos_token = self._processor.tokenizer.eos_token
|
||||
|
||||
for i in range(batch_size):
|
||||
instruction_suffix = self.prompt_suffix_template.format(instruction=tasks[i])
|
||||
if self.add_chat_template:
|
||||
suffix_for_template = instruction_suffix.removesuffix(_TRUE_ANSWER).rstrip()
|
||||
templated_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": videos[i], "fps": self.fps},
|
||||
{"type": "text", "text": f"{self.prompt_prefix}{suffix_for_template}"},
|
||||
],
|
||||
}
|
||||
]
|
||||
prompt_chat = self._processor.apply_chat_template(
|
||||
templated_messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
full_text = f"{prompt_chat}{_TRUE_ANSWER}"
|
||||
else:
|
||||
user_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": videos[i], "fps": self.fps},
|
||||
{"type": "text", "text": self.prompt_prefix},
|
||||
],
|
||||
}
|
||||
]
|
||||
prompt_chat = self._processor.apply_chat_template(
|
||||
user_messages, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
if eos_token is not None:
|
||||
prompt_chat = prompt_chat.split(eos_token)[0]
|
||||
full_text = f"{prompt_chat}{instruction_suffix}"
|
||||
|
||||
texts.append(full_text)
|
||||
|
||||
result = self._processor(
|
||||
text=texts,
|
||||
videos=videos,
|
||||
video_metadata=video_metadata,
|
||||
do_sample_frames=False,
|
||||
padding=True,
|
||||
padding_side="left",
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = result["input_ids"]
|
||||
|
||||
if input_ids.shape[-1] > self.max_length:
|
||||
raise ValueError(
|
||||
f"TOPReward input length {input_ids.shape[-1]} exceeds max_length "
|
||||
f"{self.max_length}; lower `max_frames` or raise `max_length`."
|
||||
)
|
||||
|
||||
labels = torch.full_like(input_ids, -100)
|
||||
labels[:, -1] = input_ids[:, -1]
|
||||
result["labels"] = labels
|
||||
return result
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"vlm_name": self.vlm_name,
|
||||
"image_key": self.image_key,
|
||||
"task_key": self.task_key,
|
||||
"default_task": self.default_task,
|
||||
"max_frames": self.max_frames,
|
||||
"fps": self.fps,
|
||||
"prompt_prefix": self.prompt_prefix,
|
||||
"prompt_suffix_template": self.prompt_suffix_template,
|
||||
"add_chat_template": self.add_chat_template,
|
||||
"max_length": self.max_length,
|
||||
}
|
||||
|
||||
|
||||
def make_topreward_pre_post_processors(
|
||||
config: TOPRewardConfig,
|
||||
dataset_stats: dict[str, dict[str, Any]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Pipeline that pre-encodes frames + task into Qwen-VL tensors.
|
||||
|
||||
The preprocessor adds a batch dimension if needed, runs TOPReward's
|
||||
encoder (which tokenises the full prompt and emits ``labels``), and
|
||||
moves everything to the configured device. The postprocessor is
|
||||
the identity since TOPReward outputs a single reward tensor.
|
||||
"""
|
||||
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
AddBatchDimensionProcessorStep(),
|
||||
TOPRewardEncoderProcessorStep(
|
||||
vlm_name=config.vlm_name,
|
||||
image_key=config.image_key,
|
||||
task_key=config.task_key,
|
||||
default_task=config.default_task,
|
||||
max_frames=config.max_frames,
|
||||
fps=config.fps,
|
||||
prompt_prefix=config.prompt_prefix,
|
||||
prompt_suffix_template=config.prompt_suffix_template,
|
||||
add_chat_template=config.add_chat_template,
|
||||
max_length=config.max_input_length,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device or "cpu"),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline(
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
20
src/lerobot/robots/bi_rebot_b601_follower/__init__.py
Normal file
20
src/lerobot/robots/bi_rebot_b601_follower/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .bi_rebot_b601_follower import BiRebotB601Follower
|
||||
from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
|
||||
|
||||
__all__ = ["BiRebotB601Follower", "BiRebotB601FollowerConfig"]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user