Compare commits

..

15 Commits

Author SHA1 Message Date
Pepijn
d55b581ca1 fix(language): address review — tools accessor, motion docs, conditional collate
* **`meta.tools` actually reads `info.json["tools"]`.** `DatasetInfo`
  had no `tools` field, so `from_dict` silently dropped the key (it
  warned about unknown fields then discarded them) and the property
  always returned `DEFAULT_TOOLS`. Added `tools: list[dict] | None`
  to the dataclass; `to_dict()` drops it when unset so existing
  datasets keep a clean `info.json`. Fixed the accessor to read
  `self.info.tools` (the previous `.get(...)` would have raised
  AttributeError on the dataclass anyway). Added regression tests:
  fallback when absent, round-trip from disk, and round-trip
  through `DatasetInfo.from_dict` / `to_dict`.

* **`motion` is not view-dependent — fix the docs.** The mdx claimed
  rows of style `motion` must carry `camera`, but `VIEW_DEPENDENT_STYLES
  = {"vqa", "trace"}` and the validator agrees: motion primitives are
  joint/Cartesian-frame, not pixel-space. Updated both call-out
  paragraphs in `language_and_recipes.mdx`.

* **Conditional `collate_fn` swap.** Added `meta.has_language_columns`
  and gate the `lerobot_collate_fn` swap in `lerobot_train.py` on it,
  so non-language datasets keep PyTorch's `default_collate`. Also
  added a pass-through test in `test_collate.py` that asserts on a
  plain tensor batch the custom collate matches `default_collate`
  key-for-key, plus a test for the `None`-sample drop path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 14:51:06 +02:00
Pepijn
24d2ffe3c6 fix(language): keep base install green — drop processor re-export, gate dataset-extra tests
`lerobot.processor` re-exported `RenderMessagesStep` at the package
level, so importing anything from `lerobot.processor` pulled in
`lerobot.datasets.language` → `lerobot.datasets/__init__.py` →
`require_package("datasets")`, which fails in the Tier 1 base install
that intentionally omits the `[dataset]` extra. The chain bricked
collection for unrelated suites (`tests/policies/pi0_pi05/...`,
`tests/envs/...`, etc.).

* Stop re-exporting `RenderMessagesStep` from `lerobot.processor`. The
  only consumer (the test) already imports from the submodule.
  Document the deliberate omission in the module docstring.
* Add `pytest.importorskip("datasets", ...)` (and `pandas` where
  needed) at the top of the four PR-added tests that exercise the
  language stack:
  - tests/datasets/test_language.py
  - tests/datasets/test_language_render.py
  - tests/processor/test_render_messages_processor.py
  - tests/utils/test_collate.py

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 14:12:54 +02:00
Pepijn
789f29aa56 chore: fix CI — collapse short ValueError to one line, refresh uv.lock
* `ruff format` on CI (newer version) wants the short `camera=None`
  ValueError on a single line.
* `uv.lock` was stale relative to `pyproject.toml`'s `datasets>=4.7.0`
  pin (and picked up upstream `s390x` marker fixes for cuda packages).
  CI runs `uv sync --locked` which rejected the divergence.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 14:05:42 +02:00
Pepijn
a356b12c41 fix(language): always raise on ambiguous resolver matches
`_select_one` previously skipped its ambiguity check whenever any of
`role`/`tool_name`/`camera` was set, on the assumption that the caller
had already pinned down a unique row. That left a real ambiguity hole
for VQA: with two cameras emitting `(vqa, assistant)` at the same
frame, `emitted_at(..., role="assistant")` silently picked the first
sorted row instead of telling the recipe to add `camera=...`. The
existing `test_emitted_at_raises_on_ambiguous_per_camera_vqa` test
already encoded the desired behavior.

Tighten the check: any time `len(rows) > 1` we now raise with the
selectors echoed back, so users see exactly which fields they passed
and that more is needed to disambiguate.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 14:00:45 +02:00
Pepijn
e8327b8e62 refactor(language): unify resolver dispatch and prune redundant test scaffolding
* Drop the unused `events` kwarg from `active_at`/`nth_prev`/`nth_next`;
  only `emitted_at` actually consults events. The dispatcher in
  `_resolve_spec` now passes events conditionally.
* Replace the dual `_persistent_sort_key`/`_event_sort_key` pair with a
  single `_row_sort_key` and drop the `sort_key` parameter from
  `_select_one`. Event rows lack `timestamp` (it is implicit in the
  frame) and now default to `0.0` for sort purposes — the
  `(style, role)` tiebreaker is unchanged.
* Inline `_select_latest` into `active_at` (its only caller).
* Collapse `emitted_at`'s dual-branch into one `_select_one` call.
* Tighten `_validate_persistent_resolver` to a single
  `column_for_style(style) != LANGUAGE_PERSISTENT` check.
* Parameterize `test_per_camera_blend_renders_both_views` over the two
  cameras and factor the sub-recipe builder into `_vqa_subrecipe` so
  the test no longer hand-rolls two near-identical recipe blocks.

Net -98 LOC; behavior, public resolver names, and test expectations
unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 13:15:45 +02:00
Pepijn
c450298147 Apply ruff and prettier formatting after merge
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 12:10:41 +02:00
Pepijn
5c30b14929 Merge remote-tracking branch 'origin/main' into feat/language-columns 2026-05-06 12:09:13 +02:00
Pepijn
e3e9374e2c feat(language): tool catalog in meta/info.json + LeRobotDatasetMetadata.tools
Stores OpenAI-style function schemas at ``meta/info.json["tools"]`` so
datasets can declare which tools are available (today: just ``say``;
tomorrow: per-dataset extensions). The ``DEFAULT_TOOLS`` constant
fills in for unannotated datasets so chat-template consumers don't
have to special-case anything.

Three pieces:

- ``language.py``: ``SAY_TOOL_SCHEMA`` and ``DEFAULT_TOOLS``
  constants. Single source of truth — PR 2's writer and PR 3's
  runtime tool registry will both import from here instead of
  duplicating the dict.
- ``dataset_metadata.py``: ``LeRobotDatasetMetadata.tools`` property
  reads ``info.json["tools"]`` and falls back to ``DEFAULT_TOOLS``.
  Returns deep-copied dicts so callers can mutate the result safely.
- ``docs/source/tools.mdx``: spec page covering the catalog, per-row
  invocations, and the three-step "how to add a new tool" workflow
  (declare schema, implement, register). Linked from the docs
  toctree under the Datasets section.

This lays the groundwork for PR 2's pipeline writing the catalog out
during annotation, and PR 3's ``src/lerobot/tools/`` package shipping
runnable implementations (one file per tool — first up:
``say.py`` wrapping Kyutai's pocket-tts).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:44:58 +02:00
Pepijn
c1a0c601e2 feat(language): task_aug style + automatic ${task} rephrasing rotation
Adds task-prompt diversity (Xiao 2022 / CAST) without touching
``meta/tasks.parquet`` or forcing recipes to opt in. The plan reserved
``task_aug`` as a future style; this lands it now.

- ``language.py``: add ``task_aug`` to ``CORE_STYLES`` and
  ``PERSISTENT_STYLES``. ``column_for_style("task_aug")`` returns
  ``language_persistent`` so PR 2 writers route it correctly.

- ``language_render.py``: ``_resolve_task`` now consults the persistent
  slice for rows of ``style="task_aug", role="user"``. When any exist
  it picks one deterministically by ``sample_idx`` (blake2b-keyed, not
  Python's randomized hash) so an epoch sees every rephrasing of every
  episode while the same sample still resolves identically across
  reruns. Falls back to the canonical ``meta/tasks.parquet`` task when
  no rephrasings are present, so existing datasets and unannotated runs
  keep their behaviour. Explicit ``task=`` overrides still win.

- Tests: rephrasing coverage across samples, determinism on repeat
  ``sample_idx``, fallback when persistent has no ``task_aug`` rows,
  and explicit override priority.

Recipes get this for free: any ``${task}`` placeholder rotates through
the available rephrasings. Recipes that want the literal canonical task
can override the binding.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 16:45:39 +02:00
Pepijn
1ca38d9748 fix(language): drop motion from VIEW_DEPENDENT_STYLES
Motion primitives are described in robot-frame (joint / Cartesian) terms,
not pixel space, so they are camera-agnostic. Only `vqa` (event) and
`trace` (event, pixel-trajectory) are view-dependent.

The `camera` field stays on PERSISTENT_ROW_FIELDS for schema symmetry —
the validator, resolver, and HF feature mapping behave identically across
the two columns regardless of which styles populate `camera` today —
but persistent rows now always have `camera=None` in practice.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 10:54:12 +02:00
Pepijn
5a6aa64570 feat(language): per-camera tagging on view-dependent styles
Adds a nullable `camera` field to the language row struct (both persistent
and event variants) so view-dependent styles like `vqa` can carry which
`observation.images.*` view they were grounded against. Without this,
multi-camera datasets ended up with multiple `(vqa, role)` rows at the
same timestamp that the resolver could not disambiguate.

- `language.py`: add `camera` to PERSISTENT_ROW_FIELDS / EVENT_ROW_FIELDS,
  to both Arrow struct types and the HF datasets feature mappings;
  introduce VIEW_DEPENDENT_STYLES = {vqa, motion, trace} plus
  `is_view_dependent_style` and `validate_camera_field` helpers (camera
  required iff style is view-dependent).
- `language_render.py`: thread an optional `camera=` kwarg through every
  resolver (`active_at`, `emitted_at`, `nth_prev`, `nth_next`) and through
  `_matching_rows` / `_select_*`, so recipes can disambiguate per-camera
  VQA with `emitted_at(t, style=vqa, role=assistant, camera=...)`.
  Without a `camera` filter, multi-row matches keep raising the existing
  ambiguity error — which is the desired behaviour on multi-camera data.
- `recipes/pi05_hirobot.yaml`: replace the single `ask_vqa` branch with
  `ask_vqa_top` and `ask_vqa_wrist` per-camera sub-recipes (each carrying
  the matching image block), keeping the original 0.20 budget and
  documenting the customization point for datasets with different cameras.
- Tests: schema test asserts the new field order; new tests cover
  `is_view_dependent_style`, `validate_camera_field` (both required and
  forbidden directions), per-camera `emitted_at` filtering, and the
  ambiguity error when two cameras emit `(vqa, assistant)` at the same
  timestamp without a `camera=` filter. RenderMessagesStep + dataset
  passthrough fixtures updated to include the new field.
- `docs/source/language_and_recipes.mdx`: document the `camera` field,
  the per-camera resolver pattern, and the canonical recipe convention.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 10:48:17 +02:00
Pepijn
0b06790da0 feat(language): add motion (persistent) and trace (event-only) styles
Promote the previously-reserved motion/trace styles to first-class core
styles. motion routes to language_persistent (it tracks robot state over
time); trace routes to language_events (single-moment annotations).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 14:21:49 +02:00
Pepijn
b43dc39ba4 Add docstrings to all new helpers; revert uv.lock
Covers private helpers in recipe.py, language.py, language_render.py,
and render_messages_processor.py. Also reverts uv.lock to main (it was
re-generated by `uv run` during local checks).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 14:15:03 +02:00
Pepijn
2b71221194 Address review: split persistent/event schemas, drop event timestamps
- recipe.py: derive _VALID_ROLES/_VALID_STREAMS from MessageRole/MessageStream Literals
- dataset_metadata.py: keep CODEBASE_VERSION at v3.0
- language.py: remove RESERVED_STYLES; split arrow/feature schemas into
  persistent (with timestamp) and event (without timestamp); add docstrings
- language_render.py: events use frame-row timestamp implicitly; no
  per-event timestamp filtering or sorting
- converters.py: drop unused subtask_key passthrough
- add docstrings to new public APIs (recipe, render_messages_processor, collate)
- update tests for split schemas; revert uv.lock

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 13:38:23 +02:00
Pepijn
8833d735a1 Add extensive language support 2026-04-27 10:56:32 +02:00
68 changed files with 2930 additions and 4889 deletions

View File

@@ -382,7 +382,6 @@ jobs:
--policy.path=\"\$ROBOTWIN_POLICY\" \
--env.type=robotwin \
--env.task=\"\$ROBOTWIN_TASKS\" \
--env.max_parallel_tasks=5 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval.use_async_envs=false \
@@ -483,7 +482,6 @@ jobs:
--policy.path=lerobot/smolvla_robocasa \
--env.type=robocasa \
--env.task=CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove,CloseToasterOvenDoor,SlideDishwasherRack,TurnOnSinkFaucet,NavigateKitchen,TurnOnElectricKettle \
--env.max_parallel_tasks=5 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval.use_async_envs=false \
@@ -695,7 +693,6 @@ jobs:
--env.task=\"\$ROBOMME_TASKS\" \
--env.dataset_split=test \
--env.task_ids=[0] \
--env.max_parallel_tasks=5 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval.use_async_envs=false \
@@ -803,7 +800,6 @@ jobs:
--env.type=libero_plus \
--env.task=\"\$LIBERO_PLUS_SUITE\" \
--env.task_ids=\"\$LIBERO_PLUS_TASK_IDS\" \
--env.max_parallel_tasks=5 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval.use_async_envs=false \
@@ -904,8 +900,6 @@ jobs:
--policy.path=lerobot/smolvla_vlabench \
--env.type=vlabench \
--env.task=select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
--env.episode_length=50 \
--env.max_parallel_tasks=5 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval.use_async_envs=false \

View File

@@ -19,8 +19,8 @@ on:
workflow_dispatch:
# Runs at 02:00
# schedule:
# - cron: "0 2 * * *"
schedule:
- cron: "0 2 * * *"
env:
CLOSE_ISSUE_MESSAGE: >

View File

@@ -35,7 +35,7 @@ USER root
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
cuda-nvcc-12-6 cuda-cudart-dev-12-6 \
cuda-nvcc-12-4 cuda-cudart-dev-12-4 \
libvulkan1 vulkan-tools \
&& mkdir -p /usr/share/vulkan/icd.d \
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \

View File

@@ -31,8 +31,10 @@
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
- local: dataset_subtask
title: Using Subtasks in the Dataset
- local: language_and_recipes
title: Language Columns and Recipes
- local: tools
title: Tools
- local: streaming_video_encoding
title: Streaming Video Encoding
title: "Datasets"
@@ -47,10 +49,6 @@
title: π₀-FAST (Pi0Fast)
- local: pi05
title: π₀.₅ (Pi05)
- local: eo1
title: EO-1
- local: evo1
title: EVO1
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla

View File

@@ -1,277 +0,0 @@
# Using Subtasks in LeRobot Datasets
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
## What are Subtasks?
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
1. "Approach the apple"
2. "Grasp the apple"
3. "Lift the apple"
4. "Move to basket"
5. "Release the apple"
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
width="80%"
/>
<p>
<em>Figure: Overview of subtask annotation.</em>
</p>
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
## Dataset Structure
Subtask information is stored in the dataset metadata:
```
my-dataset/
├── data/
│ └── ...
├── meta/
│ ├── info.json
│ ├── stats.json
│ ├── tasks.parquet
│ ├── subtasks.parquet # Subtask index → subtask string mapping
│ └── episodes/
│ └── ...
└── videos/
└── ...
```
### Subtasks Parquet File
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
| subtask_index | subtask (index column) |
| ------------- | ---------------------- |
| 0 | "Approach the apple" |
| 1 | "Grasp the apple" |
| 2 | "Lift the apple" |
| ... | ... |
### Frame-Level Annotations
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
```python
# Example frame data in the parquet file
{
"index": 42,
"timestamp": 1.4,
"episode_index": 0,
"task_index": 0,
"subtask_index": 2, # References "Lift the apple"
"observation.state": [...],
"action": [...],
}
```
## Annotating Datasets with Subtasks
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
After completing your annotation:
1. Click "Push to Hub" to upload your annotated dataset
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
## Loading Datasets with Subtasks
When you load a dataset with subtask annotations, the subtask information is automatically available:
```python
from lerobot.datasets import LeRobotDataset
# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Access a sample
sample = dataset[100]
# The sample includes both task and subtask information
print(sample["task"]) # "Collect the fruit"
print(sample["subtask"]) # "Grasp the apple"
print(sample["task_index"]) # tensor(0)
print(sample["subtask_index"]) # tensor(2)
```
### Checking for Subtask Support
You can check if a dataset has subtask annotations:
```python
# Check if subtasks are available
has_subtasks = (
"subtask_index" in dataset.features
and dataset.meta.subtasks is not None
)
if has_subtasks:
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
print("Subtasks:", list(dataset.meta.subtasks.index))
```
## Using Subtasks for Training
### With the Tokenizer Processor
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
```python
from lerobot.processor import TokenizerProcessorStep
# Create a tokenizer processor step
tokenizer_processor = TokenizerProcessorStep(
tokenizer_name_or_path="google/paligemma-3b-pt-224",
padding="max_length",
max_length=64,
)
# The processor will automatically tokenize subtasks if present in the batch
# and add them to the observation under:
# - "observation.subtask.tokens"
# - "observation.subtask.attention_mask"
```
When subtasks are available in the batch, the tokenizer processor adds:
- `observation.subtask.tokens`: Tokenized subtask text
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
### DataLoader with Subtasks
```python
import torch
from lerobot.datasets import LeRobotDataset
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=16,
shuffle=True,
)
for batch in dataloader:
# Access subtask information in the batch
subtasks = batch["subtask"] # List of subtask strings
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
# Use for training hierarchical policies or reward models
print(f"Batch subtasks: {set(subtasks)}")
```
## Example Datasets with Subtask Annotations
Try loading a dataset with subtask annotations:
```python
from lerobot.datasets import LeRobotDataset
# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Explore the subtasks
print("Available subtasks:")
for subtask_name in dataset.meta.subtasks.index:
print(f" - {subtask_name}")
# Get subtask distribution
subtask_counts = {}
for i in range(len(dataset)):
sample = dataset[i]
subtask = sample["subtask"]
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
print("\nSubtask distribution:")
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
print(f" {subtask}: {count} frames")
```
## Use Cases
### 1. Hierarchical Policy Training
Train policies that predict both actions and current subtask:
```python
class HierarchicalPolicy(nn.Module):
def __init__(self, num_subtasks):
super().__init__()
self.action_head = nn.Linear(hidden_dim, action_dim)
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
def forward(self, observations):
features = self.encoder(observations)
actions = self.action_head(features)
subtask_logits = self.subtask_head(features)
return actions, subtask_logits
```
### 2. Stage-Aware Reward Modeling (SARM)
Build reward models that understand task progression:
```python
# SARM predicts:
# - Stage: Which subtask is being executed (discrete)
# - Progress: How far along the subtask (continuous 0-1)
class SARMRewardModel(nn.Module):
def forward(self, observations):
features = self.encoder(observations)
stage_logits = self.stage_classifier(features)
progress = self.progress_regressor(features)
return stage_logits, progress
```
### 3. Progress Visualization
Monitor robot execution by tracking subtask progression:
```python
def visualize_execution(model, observations):
for t, obs in enumerate(observations):
action, subtask_logits = model(obs)
predicted_subtask = subtask_names[subtask_logits.argmax()]
print(f"t={t}: Executing '{predicted_subtask}'")
```
## API Reference
### LeRobotDataset Properties
| Property | Type | Description |
| --------------------------- | ---------------------- | ------------------------------------------ |
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
### Sample Keys
When subtasks are available, each sample includes:
| Key | Type | Description |
| --------------- | -------------- | ------------------------------------ |
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
| `subtask` | `str` | Natural language subtask description |
## Related Resources
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation

View File

@@ -1,168 +0,0 @@
# EO-1
EO-1 is a **Vision-Language-Action policy for robot control**. The LeRobot implementation integrates EO-1 with the standard LeRobot training, evaluation, processor interface.
## Model Overview
EO-1 uses a Qwen2.5-VL backbone for vision-language understanding and adds a continuous flow-matching action head for robot control. The policy formats each robot-control sample as a multimodal conversation: camera images are passed to Qwen2.5-VL, the robot state is represented with EO-1 state tokens, and the future action chunk is represented with EO-1 action tokens.
<img
src="https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png"
alt="An overview of EO-1"
width="85%"
/>
During training, EO-1 learns to denoise continuous action chunks at the action-token positions. During inference, it samples an action chunk, returns continuous actions, and executes `n_action_steps` from the chunk before sampling again.
### What the LeRobot Integration Covers
- Standard `policy.type=eo1` configuration through LeRobot
- Qwen2.5-VL image and text preprocessing through policy processors
- Continuous flow-matching action prediction
- Checkpoint save/load through LeRobot policy APIs
- Training with `lerobot-train` and evaluation with `lerobot-eval`
The broader EO-1 project also includes interleaved vision-text-action pretraining and multimodal reasoning workflows. This page focuses on the LeRobot robot-control policy path.
## Installation Requirements
1. Install LeRobot by following the [Installation Guide](./installation).
2. Install EO-1 dependencies by running:
```bash
pip install -e ".[eo1]"
```
3. If you want to train or evaluate on LIBERO, install the LIBERO dependencies too:
```bash
pip install -e ".[eo1,libero]"
```
EO-1 can use the standard PyTorch scaled-dot-product attention backend through `policy.attn_implementation=sdpa`. If your environment has a compatible `flash_attn` installation, you can request `policy.attn_implementation=flash_attention_2`.
## Data Requirements
EO-1 expects a LeRobot dataset with:
- At least one visual observation, for example `observation.images.image`
- `observation.state`
- `action`
- A language task instruction through the dataset `task` field
If your dataset uses different observation names, use `rename_map` to align them with the names expected by your training or evaluation setup.
## Usage
To use EO-1 in a LeRobot configuration, specify the policy type as:
```python
policy.type=eo1
```
By default, a new EO-1 policy initializes its backbone from:
```python
policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct
```
Once a LeRobot-format EO-1 checkpoint is available, load it with:
```python
policy.path=your-org/your-eo1-checkpoint
```
## Training
### Training Command Example
```bash
lerobot-train \
--dataset.repo_id=your_org/your_dataset \
--policy.type=eo1 \
--policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct \
--policy.dtype=bfloat16 \
--policy.attn_implementation=sdpa \
--policy.gradient_checkpointing=false \
--output_dir=./outputs/eo1_training \
--job_name=eo1_training \
--steps=300000 \
--batch_size=16 \
--policy.device=cuda
```
### Key Training Parameters
| Parameter | Default | Description |
| -------------------------------------- | ----------------------------- | ----------------------------------------------------------------------- |
| `policy.vlm_base` | `Qwen/Qwen2.5-VL-3B-Instruct` | Qwen2.5-VL checkpoint used to initialize a new policy |
| `policy.dtype` | `auto` | Backbone dtype request: `auto`, `bfloat16`, or `float32` |
| `policy.attn_implementation` | `None` | Optional Qwen attention backend, such as `sdpa` |
| `policy.gradient_checkpointing` | `false` | Reduces memory usage during training |
| `policy.chunk_size` | `8` | Number of future actions predicted per chunk |
| `policy.n_action_steps` | `8` | Number of actions consumed from a sampled chunk |
| `policy.num_denoise_steps` | `10` | Number of flow-matching denoising steps used during sampling |
| `policy.max_state_dim` | `32` | State padding dimension |
| `policy.max_action_dim` | `32` | Action padding dimension |
| `policy.force_fp32_autocast` | `true` | Keeps the flow head in fp32 even when the backbone uses mixed precision |
| `policy.supervise_padding_action_dims` | `true` | Controls whether padded action dimensions are supervised |
| `policy.supervise_padding_actions` | `true` | Controls whether padded future action rows are supervised |
## Evaluation
EO-1 can be evaluated through `lerobot-eval` once you have a LeRobot-format checkpoint:
```bash
lerobot-eval \
--policy.path=your-org/your-eo1-checkpoint \
--env.type=libero \
--env.task=libero_object \
--eval.batch_size=1 \
--eval.n_episodes=20
```
For datasets or environments whose camera names differ from the checkpoint configuration, pass a `rename_map`:
```bash
lerobot-eval \
--policy.path=your-org/your-eo1-checkpoint \
--env.type=libero \
--env.task=libero_object \
--rename_map='{"observation.images.image2":"observation.images.wrist_image"}'
```
## Configuration Notes
### Image Processing
EO-1 uses the Qwen2.5-VL processor. The `policy.image_min_pixels` and `policy.image_max_pixels` settings control the image resizing bounds before the visual tokens are passed into the backbone.
### State and Action Dimensions
The policy pads state and action vectors to `policy.max_state_dim` and `policy.max_action_dim` before the EO-1 flow head. Predictions are cropped back to the original action dimension before being returned by the policy.
### Attention Backend
Use `policy.attn_implementation=sdpa` for a portable setup. Use `flash_attention_2` only when `flash_attn` is installed and compatible with your environment.
## References
- [EO-1 project](https://github.com/EO-Robotics/EO1)
- [EO-1 paper](https://arxiv.org/abs/2508.21112)
- [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
## Citation
```bibtex
@article{eo1,
title={EO-1: Interleaved Vision-Text-Action Pretraining for General Robot Control},
author={Delin Qu and Haoming Song and Qizhi Chen and Zhaoqing Chen and Xianqiang Gao and Xinyi Ye and Qi Lv and Modi Shi and Guanghui Ren and Cheng Ruan and Maoqing Yao and Haoran Yang and Jiacheng Bao and Bin Zhao and Dong Wang},
journal={arXiv preprint},
year={2025},
url={https://arxiv.org/abs/2508.21112}
}
```
## License
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream EO-1 model and dataset pages for the licenses of released EO-1 checkpoints and data.

View File

@@ -1,132 +0,0 @@
# EVO1
EVO1 is a Vision-Language-Action policy for robot control built around an InternVL3 backbone and a continuous flow-matching action head. This LeRobot integration exposes EVO1 as a standard policy type so it can be trained and evaluated with the usual LeRobot dataset, checkpoint, and processor APIs.
## Model Overview
The policy embeds one or more camera images and the language task prompt with InternVL3, pads robot state/action vectors to fixed maximum dimensions, and predicts future action chunks with a flow-matching action head. During inference, the policy samples an action chunk and returns `n_action_steps` actions from that chunk before sampling again.
### What the LeRobot Integration Covers
- Standard `policy.type=evo1` configuration through LeRobot
- InternVL3 image/text embedding with optional FlashAttention fallback
- Stage-based finetuning controls for action-head-only and VLM finetuning runs
- Continuous flow-matching action prediction
- Checkpoint save/load through LeRobot policy APIs
- Training with `lerobot-train` and evaluation with standard policy inference APIs
The broader EVO1 project may include additional training scripts and dataset tooling. This page focuses on the LeRobot robot-control policy path.
## Installation Requirements
1. Install LeRobot by following the [Installation Guide](./installation).
2. Install EVO1 dependencies:
```bash
pip install -e ".[evo1]"
```
3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available.
EVO1 uses InternVL3 through the Hugging Face `transformers` remote-code path, so the first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory.
## Data Requirements
EVO1 expects a LeRobot dataset with:
- One to `policy.max_views` visual observations, for example `observation.images.image`
- `observation.state`
- `action`
- A language task instruction in the dataset `task` field, or another field configured with `policy.task_field`
State and action vectors are padded to `policy.max_state_dim` and `policy.max_action_dim`. Predictions are cropped back to the dataset action dimension before being returned.
## Usage
To use EVO1 in a LeRobot configuration, specify:
```python
policy.type=evo1
```
By default, a new EVO1 policy initializes its VLM from:
```python
policy.vlm_model_name=OpenGVLab/InternVL3-1B
```
Once a LeRobot-format EVO1 checkpoint is available, load it with:
```python
policy.path=your-org/your-evo1-checkpoint
```
## Training
### Stage 1
Stage 1 freezes the VLM and trains the action head:
```bash
lerobot-train \
--dataset.repo_id=your_org/your_dataset \
--policy.type=evo1 \
--policy.training_stage=stage1 \
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
--policy.device=cuda \
--policy.chunk_size=50 \
--policy.n_action_steps=50 \
--policy.max_state_dim=24 \
--policy.max_action_dim=24 \
--policy.optimizer_lr=1e-5 \
--batch_size=4 \
--steps=5000 \
--output_dir=./outputs/evo1_stage1
```
### Stage 2
Stage 2 finetunes the VLM branches and action head. A common workflow starts from a Stage 1 checkpoint:
```bash
lerobot-train \
--dataset.repo_id=your_org/your_dataset \
--policy.path=./outputs/evo1_stage1/checkpoints/005000/pretrained_model \
--policy.training_stage=stage2 \
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
--policy.device=cuda \
--policy.chunk_size=50 \
--policy.n_action_steps=50 \
--policy.max_state_dim=24 \
--policy.max_action_dim=24 \
--policy.optimizer_lr=1e-5 \
--batch_size=4 \
--steps=80000 \
--output_dir=./outputs/evo1_stage2
```
### Key Training Parameters
| Parameter | Default | Description |
| --------------------------------------------- | ------------------------ | ----------------------------------------------------------------- |
| `policy.vlm_model_name` | `OpenGVLab/InternVL3-1B` | InternVL3 checkpoint or local model directory |
| `policy.training_stage` | `stage1` | `stage1` trains the action head; `stage2` finetunes VLM branches |
| `policy.vlm_num_layers` | `14` | Number of InternVL3 language layers kept for the policy |
| `policy.vlm_dtype` | `bfloat16` | Requested VLM dtype |
| `policy.use_flash_attn` | `true` | Requests FlashAttention when installed; otherwise falls back |
| `policy.enable_gradient_checkpointing` | `true` | Enables checkpointing on supported InternVL3 modules |
| `policy.gradient_checkpointing_use_reentrant` | `false` | Reentrant setting passed to gradient checkpointing when supported |
| `policy.chunk_size` | `50` | Number of future actions predicted per chunk |
| `policy.n_action_steps` | `50` | Number of actions consumed from a sampled chunk |
| `policy.max_state_dim` | `24` | State padding dimension |
| `policy.max_action_dim` | `24` | Action padding dimension |
| `policy.task_field` | `task` | Batch field used as the language prompt |
## References
- [EVO1 repository](https://github.com/MINT-SJTU/Evo-1)
- [InternVL3-1B](https://huggingface.co/OpenGVLab/InternVL3-1B)
## License
This LeRobot integration follows the Apache 2.0 License used by LeRobot. Check the upstream EVO1 and InternVL3 model pages for the licenses of released checkpoints and data.

View File

@@ -0,0 +1,147 @@
# Language columns and recipes
Most LeRobot datasets ship with a single `task` string per episode — fine for
short, single-instruction skills, but not enough for the longer-horizon,
multi-modal robot policies the field is moving toward (high-level planning,
memory, interjections, VQA, tool use). To support those policies without
forking the dataset format, LeRobot extends `LeRobotDataset` with two optional
language columns and a small recipe layer that turns those rows into
chat-style training samples on the fly.
The design splits cleanly into three layers:
1. **Data in the dataset** — language annotations stored next to frames in
`data/chunk-*/file-*.parquet` as two optional columns (`language_persistent`
and `language_events`). Datasets without these columns keep their existing
behavior.
2. **Recipe** — a YAML file that declares which annotation rows to bind and
how to lay them out as chat turns (`role`, `content`, optional images,
optional tool calls). Recipes are pure config; no Python required to add a
new one.
3. **Training format** — at sample time, `RenderMessagesStep` resolves the
recipe against the per-frame annotations and emits HF-style `messages` plus
LeRobot-specific sidecars (`message_streams`, `target_message_indices`)
that policy processors consume.
This page describes each layer in turn.
## Layer 1 — language columns in the dataset
The two optional columns live next to frame data in
`data/chunk-*/file-*.parquet`:
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
Both columns share the same row shape (event rows omit `timestamp` because the
frame the row sits on already provides it):
```text
role: string
content: string | null
style: string | null
timestamp: float64 # persistent rows only
camera: string | null # observation.images.* feature key, view-dependent rows only
tool_calls: list[Json] | null
```
The `camera` field tags rows whose `content` is grounded in a specific camera
view. Rows of view-dependent styles (`vqa` and `trace`) MUST set `camera` to
the matching `observation.images.*` feature key. Rows of every other style —
including `motion`, which describes robot-frame primitives in joint / Cartesian
terms — MUST leave `camera` as `null`. Pipeline writers and the validator
enforce this via `validate_camera_field(style, camera)`.
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
### Architecture
The language stack itself has three internal modules backing layer 1:
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
2. `lerobot.datasets.language_render` resolves rows and renders messages.
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
### Temporal semantics
Persistent styles are active after emission until replaced:
- `active_at(t, style=subtask)`
- `nth_prev(style=memory, offset=1)`
- `nth_next(style=subtask, offset=1)`
Event styles only exist on their exact timestamp:
- `emitted_at(t, style=interjection)`
- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
- `emitted_at(t, role=assistant, tool_name=say)`
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
### View-dependent resolution
For view-dependent styles (`vqa` and `trace`), the resolver gains a
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
camera at the same timestamp; without `camera=`, those resolvers see two
matches and raise an ambiguity error. Recipes consume each camera through its
own binding plus a matching image block, e.g.
```yaml
ask_vqa_top:
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- { type: image, feature: observation.images.top }
- { type: text, text: "${vqa_query}" }
- {
role: assistant,
content: "${vqa}",
stream: high_level,
target: true,
if_present: vqa,
}
```
Add one such sub-recipe per camera the dataset records.
## Layer 2 — recipe anatomy
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
declare which annotation rows to pull (via `bindings`) and how to compose them
into chat turns (`messages`).
```yaml
messages:
- { role: user, content: "${task}", stream: high_level }
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
```
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
time, exactly one branch is selected deterministically from the sample index,
so different frames train different objectives (e.g. memory updates vs.
low-level execution vs. VQA) without any Python wiring.
## Layer 3 — training format
Rendered samples use HF-style chat messages plus LeRobot sidecars:
```python
sample["messages"]
sample["message_streams"]
sample["target_message_indices"]
```
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
## Graceful absence
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.

200
docs/source/tools.mdx Normal file
View File

@@ -0,0 +1,200 @@
# Tools
LeRobot v3.1 supports **tool calls** in policies — assistant messages can
emit structured invocations like `say(text="OK, starting now")` that the
runtime dispatches to a real implementation (TTS, controller, logger, …).
This page covers:
1. Where the tool catalog lives.
2. How the annotation pipeline produces tool-call atoms.
3. How to add your own tool.
## Where tools are declared
Two layers.
**The catalog** — a list of OpenAI-style function schemas — lives at
`meta/info.json["tools"]` on each dataset. Example:
```json
{
"features": { "...": "..." },
"tools": [
{
"type": "function",
"function": {
"name": "say",
"description": "Speak a short utterance to the user via the TTS executor.",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The verbatim text to speak."
}
},
"required": ["text"]
}
}
}
]
}
```
Read it via the dataset metadata accessor:
```python
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
tools = meta.tools # list[dict] — OpenAI tool schemas
```
If the dataset's `info.json` doesn't declare any tools, `meta.tools`
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
single-entry list with the canonical `say` schema. So unannotated
datasets and chat-template consumers keep working without any
configuration:
```python
prompt_str = tokenizer.apply_chat_template(
sample["messages"],
tools=meta.tools, # works either way
add_generation_prompt=False,
tokenize=False,
)
```
**The implementations** — runnable Python — live under
`src/lerobot/tools/`, one file per tool. The canonical `say`
implementation wraps Kyutai's pocket-tts model.
## Per-row tool _invocations_
The catalog above describes _what can be called_. The actual _call_ — the
function name plus the argument values — is stored per-row, on the
assistant atoms in `language_events`:
```python
{
"role": "assistant",
"content": null,
"style": null,
"timestamp": 12.4,
"camera": null,
"tool_calls": [
{ "type": "function",
"function": { "name": "say", "arguments": { "text": "On it." } } }
]
}
```
Recipes splice these into rendered messages via `tool_calls_from`:
```yaml
user_interjection_response:
bindings:
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- { role: user, content: "${task}", stream: high_level }
- {
role: assistant,
content: "${current_plan}",
stream: high_level,
target: true,
tool_calls_from: speech,
}
```
The model's training target is one assistant turn that carries both the
plan text _and_ the `say` tool call. At inference, the runtime parses
the generated text back into structured `tool_calls` and dispatches to
the matching implementation.
## How to add your own tool
Three steps. Concrete example: a `record_observation` tool the policy
can call to capture an extra observation outside the regular control
loop.
### Step 1 — declare the schema
Add an entry under `meta/info.json["tools"]`. Either edit the file
directly on disk _before_ running the annotation pipeline (it'll be
preserved) or hand it to `lerobot-annotate` via a config flag.
```json
{
"tools": [
{ "type": "function", "function": { "name": "say", "...": "..." } },
{
"type": "function",
"function": {
"name": "record_observation",
"description": "Capture a high-resolution still image for the user.",
"parameters": {
"type": "object",
"properties": {
"label": {
"type": "string",
"description": "Short label for the saved image."
}
},
"required": ["label"]
}
}
}
]
}
```
The schema follows OpenAI's function-calling convention exactly, so the
chat template can render it natively.
### Step 2 — implement the call
Create `src/lerobot/tools/record_observation.py`:
```python
from .base import Tool
from typing import Any
RECORD_OBSERVATION_SCHEMA: dict[str, Any] = { "...": "..." } # mirrors the JSON above
class RecordObservationTool:
name = "record_observation"
schema = RECORD_OBSERVATION_SCHEMA
def __init__(self, schema: dict | None = None, output_dir: str = "."):
self.output_dir = output_dir
def call(self, arguments: dict) -> str:
label = arguments["label"]
# ... save the latest camera frame to <output_dir>/<label>.png ...
return f"saved {label}.png"
```
One file per tool keeps dependencies isolated — `record_observation`
might pull `pillow`, while `say` pulls `pocket-tts`. Users installing
only the tools they need avoid heavy transitive deps.
### Step 3 — register it
Add to `src/lerobot/tools/registry.py`:
```python
from .record_observation import RecordObservationTool
TOOL_REGISTRY["record_observation"] = RecordObservationTool
```
That's it. At runtime `get_tools(meta)` looks up each schema in
`meta.tools`, instantiates the matching registered class, and returns
a name → instance dict the dispatcher can route into.
If you want to use a tool _without_ writing an implementation (e.g. for
training-time chat-template formatting only), step 1 alone is enough —
the model still learns to _generate_ the call. Steps 2 and 3 are only
needed to actually _execute_ it at inference.

View File

@@ -59,8 +59,8 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
dependencies = [
# Core ML
"torch>=2.7,<2.13.0",
"torchvision>=0.22.0,<0.28.0",
"torch>=2.7,<2.11.0",
"torchvision>=0.22.0,<0.26.0",
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
"opencv-python-headless>=4.9.0,<4.14.0",
"Pillow>=10.0.0,<13.0.0",
@@ -95,11 +95,11 @@ dependencies = [
# ── Feature-scoped extras ──────────────────────────────────
dataset = [
"datasets>=4.0.0,<5.0.0",
"datasets>=4.7.0,<5.0.0",
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]",
"torchcodec>=0.3.0,<0.13.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10), 0.11 needs torch==2.11, 0.12 needs torch==2.12.
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
"jsonlines>=4.0.0,<5.0.0",
]
training = [
@@ -194,8 +194,6 @@ groot = [
]
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
evo1 = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
@@ -259,7 +257,6 @@ all = [
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[evo1]",
"lerobot[hilserl]",
"lerobot[async]",
"lerobot[dev]",
@@ -336,7 +333,6 @@ ignore = [
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
"src/lerobot/policies/evo1/**" = ["N801", "N812"]
[tool.ruff.lint.isort]
combine-as-imports = true

View File

@@ -24,6 +24,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
from .dataset import DatasetRecordConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .recipe import MessageTurn, TrainingRecipe, load_recipe
from .types import (
FeatureType,
NormalizationMode,
@@ -43,7 +44,10 @@ __all__ = [
"DatasetRecordConfig",
"DatasetConfig",
"EvalConfig",
"MessageTurn",
"PeftConfig",
"PreTrainedConfig",
"TrainingRecipe",
"WandBConfig",
"load_recipe",
]

View File

@@ -0,0 +1,193 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, get_args
MessageRole = Literal["user", "assistant", "system", "tool"]
MessageStream = Literal["high_level", "low_level"]
DEFAULT_BINDINGS = {
"subtask": "active_at(t, style=subtask)",
"memory": "active_at(t, style=memory)",
"plan": "active_at(t, style=plan)",
"speech": "emitted_at(t, role=assistant, tool_name=say)",
"interjection": "emitted_at(t, style=interjection)",
"vqa": "emitted_at(t, style=vqa, role=assistant)",
"vqa_query": "emitted_at(t, style=vqa, role=user)",
}
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
_VALID_ROLES = frozenset(get_args(MessageRole))
_VALID_STREAMS = frozenset(get_args(MessageStream))
@dataclass
class MessageTurn:
"""A single chat-style turn in a recipe template.
``content`` may be a plain string, a list of HF-style multimodal blocks, or
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
training target, and ``if_present`` skips the turn when the named binding
resolves to ``None``.
"""
role: MessageRole
content: str | list[dict[str, Any]] | None = None
stream: MessageStream | None = None
target: bool = False
if_present: str | None = None
tool_calls_from: str | None = None
def __post_init__(self) -> None:
"""Validate role, stream, and content after dataclass construction."""
if self.role not in _VALID_ROLES:
raise ValueError(f"Unsupported message role: {self.role!r}")
if self.stream is not None and self.stream not in _VALID_STREAMS:
raise ValueError(f"Unsupported message stream: {self.stream!r}")
if self.content is None and self.tool_calls_from is None:
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
if self.content is not None and not isinstance(self.content, (str, list)):
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
if isinstance(self.content, list):
for block in self.content:
if not isinstance(block, dict) or "type" not in block:
raise ValueError(
"Multimodal content blocks must be HF-style dictionaries with a type key."
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
"""Construct a :class:`MessageTurn` from a plain dictionary."""
return cls(**data)
@dataclass
class TrainingRecipe:
"""A recipe describing how to render training samples from language rows.
A recipe is either a *message recipe* (``messages`` plus optional
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
sub-recipes). ``weight`` is only meaningful inside a blend.
"""
messages: list[MessageTurn] | None = None
bindings: dict[str, str] | None = None
blend: dict[str, TrainingRecipe] | None = None
weight: float | None = None
def __post_init__(self) -> None:
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
if self.messages is not None and self.blend is not None:
raise ValueError("TrainingRecipe must set only one of messages or blend.")
if self.messages is None and self.blend is None:
raise ValueError("TrainingRecipe must set one of messages or blend.")
if self.messages is not None:
self._validate_message_recipe()
if self.blend is not None:
self._validate_blend_recipe()
@classmethod
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
data = dict(data)
if data.get("messages") is not None:
data["messages"] = [
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
for turn in data["messages"]
]
if data.get("blend") is not None:
data["blend"] = {
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
for name, recipe in data["blend"].items()
}
return cls(**data)
@classmethod
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
import yaml # type: ignore[import-untyped]
with open(path) as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
return cls.from_dict(data)
def _validate_message_recipe(self) -> None:
"""Ensure every templated binding is known and at least one turn is a target."""
assert self.messages is not None
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
for turn in self.messages:
missing = self._referenced_bindings(turn) - known_bindings
if missing:
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
if not any(turn.target for turn in self.messages):
raise ValueError("Message recipes must contain at least one target turn.")
def _validate_blend_recipe(self) -> None:
"""Ensure each blend component is a non-empty, weighted message recipe."""
assert self.blend is not None
if not self.blend:
raise ValueError("Blend recipes must contain at least one component.")
for name, recipe in self.blend.items():
if recipe.blend is not None:
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
if recipe.messages is None:
raise ValueError(f"Blend component {name!r} must define messages.")
if recipe.weight is None:
raise ValueError(f"Blend component {name!r} must define weight.")
if recipe.weight <= 0:
raise ValueError(f"Blend component {name!r} must have a positive weight.")
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
"""Return the binding names that ``turn`` references via placeholders or attributes."""
names: set[str] = set()
if turn.if_present is not None:
names.add(turn.if_present)
if turn.tool_calls_from is not None:
names.add(turn.tool_calls_from)
names.update(_placeholders_in_content(turn.content))
return names
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
if content is None:
return set()
if isinstance(content, str):
return set(_PLACEHOLDER_RE.findall(content))
names: set[str] = set()
for block in content:
for value in block.values():
if isinstance(value, str):
names.update(_PLACEHOLDER_RE.findall(value))
return names
def load_recipe(path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
return TrainingRecipe.from_yaml(path)

View File

@@ -256,9 +256,7 @@ class TrainPipelineConfig(HubMixin):
) from e
cli_args = kwargs.pop("cli_args", [])
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
# Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
if config_file is not None and config_file.endswith(".json"):
if config_file is not None:
with open(config_file) as f:
config = json.load(f)
migrated_config = _migrate_legacy_rabc_fields(config)

View File

@@ -37,6 +37,14 @@ from .dataset_tools import (
from .factory import make_dataset, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats
from .language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
column_for_style,
)
from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
@@ -53,10 +61,15 @@ __all__ = [
"CODEBASE_VERSION",
"DEFAULT_EPISODES_PATH",
"DEFAULT_QUANTILES",
"EVENT_ONLY_STYLES",
"EpisodeAwareSampler",
"LANGUAGE_EVENTS",
"LANGUAGE_PERSISTENT",
"LeRobotDataset",
"LeRobotDatasetMetadata",
"MultiLeRobotDataset",
"PERSISTENT_STYLES",
"STYLE_REGISTRY",
"StreamingLeRobotDataset",
"VideoEncodingManager",
"add_features",
@@ -66,6 +79,7 @@ __all__ = [
"convert_image_to_video_dataset",
"create_initial_features",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
"get_feature_stats",
"load_episodes",

View File

@@ -512,7 +512,7 @@ def compute_episode_stats(
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
if features[key]["dtype"] in {"string", "language"}:
continue
if features[key]["dtype"] in ["image", "video"]:

View File

@@ -34,7 +34,6 @@ from .io_utils import (
load_episodes,
load_info,
load_stats,
load_subtasks,
load_tasks,
write_info,
write_stats,
@@ -175,7 +174,6 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -318,6 +316,39 @@ class LeRobotDatasetMetadata:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def has_language_columns(self) -> bool:
"""Return ``True`` if the dataset declares any language column.
Used to gate language-aware code paths (collate, render step) so
unannotated datasets keep PyTorch's default collate behavior.
"""
from .language import LANGUAGE_COLUMNS # noqa: PLC0415 (avoid circular import)
return any(col in self.features for col in LANGUAGE_COLUMNS)
@property
def tools(self) -> list[dict]:
"""OpenAI-style tool schemas declared by this dataset.
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
can mutate the result safely. Falls back to
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
``say`` schema) when the dataset doesn't declare any — that way
unannotated datasets and chat-template consumers
(``apply_chat_template(messages, tools=meta.tools)``) keep
working out of the box.
Implementations live under :mod:`lerobot.tools` (one file per
tool); see ``docs/source/tools.mdx`` for the authoring guide.
"""
from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import)
declared = self.info.tools
if declared:
return [dict(t) for t in declared]
return [dict(t) for t in DEFAULT_TOOLS]
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
@@ -633,7 +664,6 @@ class LeRobotDatasetMetadata:
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(

View File

@@ -295,9 +295,4 @@ class DatasetReader:
task_idx = item["task_index"].item()
item["task"] = self._meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
return item

View File

@@ -22,6 +22,12 @@ from PIL import Image as PILImage
from lerobot.utils.constants import DEFAULT_FEATURES
from lerobot.utils.utils import is_valid_numpy_dtype_string
from .language import (
LANGUAGE_PERSISTENT,
is_language_column,
language_events_column_feature,
language_persistent_column_feature,
)
from .utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -46,7 +52,13 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
"""
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video":
if is_language_column(key):
hf_features[key] = (
language_persistent_column_feature()
if key == LANGUAGE_PERSISTENT
else language_events_column_feature()
)
elif ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
@@ -242,6 +254,8 @@ def validate_feature_dtype_and_shape(
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
elif expected_dtype == "language":
return ""
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")

View File

@@ -34,7 +34,6 @@ from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_di
from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_EPISODES_PATH,
DEFAULT_SUBTASKS_PATH,
DEFAULT_TASKS_PATH,
EPISODES_DIR,
INFO_PATH,
@@ -186,14 +185,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
@@ -265,11 +256,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
if key in {"language_persistent", "language_events"}:
continue
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None:
elif first_item is None or isinstance(first_item, dict):
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
@@ -305,7 +298,11 @@ def item_to_torch(item: dict) -> dict:
dict: Dictionary with all tensor-like items converted to torch.Tensor.
"""
for key, val in item.items():
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
if isinstance(val, (np.ndarray | list)) and key not in [
"task",
"language_persistent",
"language_events",
]:
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item

View File

@@ -0,0 +1,234 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Literal
import datasets
import pyarrow as pa
LANGUAGE_PERSISTENT = "language_persistent"
LANGUAGE_EVENTS = "language_events"
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
CORE_STYLES = {
"subtask",
"plan",
"memory",
"motion",
"interjection",
"vqa",
"trace",
"task_aug",
}
EXTENDED_STYLES = set()
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
# feature key. Rows of every other style MUST have ``camera=None``. ``motion``
# is intentionally NOT in this set: motion primitives are described in
# robot-frame (joint / Cartesian) terms, not pixel space, so they are
# camera-agnostic. ``trace`` is the pixel-trajectory event style and IS
# view-dependent. The ``camera`` field nevertheless lives on
# ``PERSISTENT_ROW_FIELDS`` too so the schema, validator, and resolver
# behave symmetrically across the two columns; persistent rows simply
# always have ``camera=None`` in practice today.
VIEW_DEPENDENT_STYLES = {"vqa", "trace"}
LanguageColumn = Literal["language_persistent", "language_events"]
def _json_arrow_type() -> pa.DataType:
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
return pa.json_() if hasattr(pa, "json_") else pa.string()
def _json_feature() -> object:
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
def language_persistent_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single persistent language row.
Persistent rows carry their own ``timestamp`` because they represent a state
that became active at a specific moment and remains active until superseded.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("timestamp", pa.float64(), nullable=False),
pa.field("camera", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_event_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single event language row.
Event rows have no ``timestamp`` field: each event is stored on the dataset
row whose frame timestamp is the event's firing time.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("camera", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_persistent_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_persistent`` column."""
return pa.list_(language_persistent_row_arrow_type())
def language_events_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_events`` column."""
return pa.list_(language_event_row_arrow_type())
def language_persistent_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"timestamp": datasets.Value("float64"),
"camera": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_event_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for an event language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"camera": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_persistent_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
return datasets.List(language_persistent_row_feature())
def language_events_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
return datasets.List(language_event_row_feature())
def language_feature_info() -> dict[str, dict]:
"""Return the ``info["features"]`` entries for both language columns."""
return {
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
}
def is_language_column(key: str) -> bool:
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
return key in LANGUAGE_COLUMNS
def is_view_dependent_style(style: str | None) -> bool:
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
return style in VIEW_DEPENDENT_STYLES
def validate_camera_field(style: str | None, camera: str | None) -> None:
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
a non-view-dependent style carries one. Pipeline writers and the validator
should call this on every emitted row.
"""
if is_view_dependent_style(style):
if not camera:
raise ValueError(
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
f"field referencing an 'observation.images.*' feature key."
)
elif camera is not None:
raise ValueError(f"Rows of style {style!r} must have camera=None; got camera={camera!r}.")
# --- Tool registry --------------------------------------------------------
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
# of OpenAI-style function schemas. The runtime / training stack reads them
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
# fallback when the dataset doesn't declare any). Implementations live
# under :mod:`lerobot.tools` (one file per tool); see
# ``docs/source/tools.mdx`` for the authoring guide.
SAY_TOOL_SCHEMA: dict = {
"type": "function",
"function": {
"name": "say",
"description": "Speak a short utterance to the user via the TTS executor.",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The verbatim text to speak.",
}
},
"required": ["text"],
},
},
}
"""Canonical schema for the ``say`` tool emitted by the steerable
annotation pipeline (PR 2 Module 2). Single source of truth — PR 2's
writer, PR 3's runtime tool registry, and the dataset visualizer all
import this constant rather than duplicating the dict."""
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
chat-template consumers (``apply_chat_template(messages, tools=...)``)
keep working out of the box."""
def column_for_style(style: str | None) -> LanguageColumn:
"""Map a language style to the column where rows of that style are stored.
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
to :data:`LANGUAGE_EVENTS`.
"""
if style is None:
return LANGUAGE_EVENTS
if style in PERSISTENT_STYLES:
return LANGUAGE_PERSISTENT
if style in EVENT_ONLY_STYLES:
return LANGUAGE_EVENTS
raise ValueError(f"Unknown language style: {style!r}")

View File

@@ -0,0 +1,532 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import hashlib
import re
from collections.abc import Sequence
from typing import Any
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
from .language import LANGUAGE_PERSISTENT, column_for_style
LanguageRow = dict[str, Any]
RenderedMessages = dict[str, list[Any]]
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
def active_at(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row of ``style`` that is active at time ``t``.
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
``camera`` selector. Only valid for persistent styles.
"""
_validate_persistent_resolver("active_at", style)
matches = [
row
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
if _timestamp(row) <= t
]
if not matches:
return None
latest_ts = max(_timestamp(row) for row in matches)
return _select_one(
[row for row in matches if _timestamp(row) == latest_ts],
style=style,
role=role,
tool_name=tool_name,
camera=camera,
)
def emitted_at(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the row of ``style`` emitted at exactly time ``t``.
For persistent styles, this matches persistent rows whose own ``timestamp``
equals ``t``. For event styles, the ``events`` list is assumed to come from
the dataset row at frame ``t`` (event rows carry no timestamp of their own),
so all matching event rows are considered emitted at ``t``. ``camera``
filters by the row's ``camera`` field — required to disambiguate when
multiple view-dependent rows share ``(t, role)`` across cameras.
"""
if column_for_style(style) == LANGUAGE_PERSISTENT:
matches = [
row
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
if _timestamp(row) == t
]
else:
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera)
def nth_prev(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that was active ``offset`` steps before ``t``.
Walks back through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
one ``offset`` positions before the row active at ``t``. Only valid for
persistent styles.
"""
return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera)
def nth_next(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
Walks forward through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
one ``offset`` positions after the row active at ``t``. Only valid for
persistent styles.
"""
return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera)
def render_sample(
*,
recipe: TrainingRecipe,
persistent: Sequence[LanguageRow] | None,
events: Sequence[LanguageRow] | None,
t: float,
sample_idx: int,
task: str | None = None,
dataset_ctx: Any | None = None,
) -> RenderedMessages | None:
"""Render the chat-style messages for a single dataset sample.
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
at frame timestamp ``t``, then expands the recipe's message templates.
Returns ``None`` if the resolved sample contains no target message.
"""
persistent_rows = _normalize_rows(persistent or [])
event_rows = _normalize_rows(events or [])
selected_recipe = _select_recipe(recipe, sample_idx)
bindings = _resolve_bindings(
selected_recipe,
persistent=persistent_rows,
events=event_rows,
t=t,
sample_idx=sample_idx,
task=task,
dataset_ctx=dataset_ctx,
)
return _render_message_recipe(selected_recipe, bindings)
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
if recipe.blend is None:
return recipe
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
if total_weight <= 0:
raise ValueError("Blend weights must sum to a positive value.")
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
cumulative = 0.0
last_component: TrainingRecipe | None = None
for component in recipe.blend.values():
last_component = component
cumulative += component.weight or 0.0
if draw < cumulative:
return component
assert last_component is not None
return last_component
def _resolve_bindings(
recipe: TrainingRecipe,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
sample_idx: int,
task: str | None,
dataset_ctx: Any | None,
) -> dict[str, LanguageRow | str | None]:
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
bindings: dict[str, LanguageRow | str | None] = {
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
}
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
for name, spec in specs.items():
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
return bindings
def _resolve_task(
task: str | None,
dataset_ctx: Any | None,
*,
persistent: Sequence[LanguageRow] = (),
sample_idx: int = 0,
) -> str | None:
"""Return the task string for ``sample_idx``.
Resolution order:
1. Explicit ``task`` override (caller-supplied) wins.
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
deterministically pick one by ``sample_idx`` so each frame of an
episode rotates through the available rephrasings across an epoch.
This realizes Xiao 2022 / CAST-style task-prompt diversity without
changing ``meta/tasks.parquet`` and without forcing recipes to opt
in: ``${task}`` automatically picks a rephrasing when one exists,
and falls back to the canonical task otherwise. Recipes that want
the literal canonical task can override the binding.
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
backed by ``meta/tasks.parquet``).
"""
if task is not None:
return task
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
if aug_rows:
# Deterministic, blake2b-based pick keyed on sample_idx so the
# rotation is reproducible across runs (Python's built-in ``hash``
# is process-randomized).
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
idx = int.from_bytes(digest, "big") % len(aug_rows)
chosen = aug_rows[idx].get("content")
if chosen:
return str(chosen)
if dataset_ctx is None:
return None
if isinstance(dataset_ctx, dict):
return dataset_ctx.get("task")
return getattr(dataset_ctx, "task", None)
def _resolve_spec(
spec: str,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
) -> LanguageRow | None:
"""Parse a single binding's resolver expression and dispatch to its function."""
match = _RESOLVER_RE.match(spec.strip())
if match is None:
raise ValueError(f"Invalid resolver expression: {spec!r}")
name = match.group("name")
kwargs = _parse_resolver_args(match.group("args"))
kwargs.pop("t_arg", None)
if name == "emitted_at":
return emitted_at(t, persistent=persistent, events=events, **kwargs)
if name == "active_at":
return active_at(t, persistent=persistent, **kwargs)
if name == "nth_prev":
return nth_prev(t, persistent=persistent, **kwargs)
if name == "nth_next":
return nth_next(t, persistent=persistent, **kwargs)
raise ValueError(f"Unknown language resolver: {name!r}")
def _parse_resolver_args(args: str) -> dict[str, Any]:
"""Parse a comma-separated resolver argument list into a kwargs dict."""
kwargs: dict[str, Any] = {}
if not args.strip():
return kwargs
parts = [part.strip() for part in args.split(",") if part.strip()]
for part in parts:
if part == "t":
kwargs["t_arg"] = True
continue
if "=" not in part:
raise ValueError(f"Invalid resolver argument: {part!r}")
key, value = (item.strip() for item in part.split("=", 1))
if key == "offset":
kwargs[key] = int(value)
else:
kwargs[key] = value.strip("\"'")
return kwargs
def _render_message_recipe(
recipe: TrainingRecipe,
bindings: dict[str, LanguageRow | str | None],
) -> RenderedMessages | None:
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
assert recipe.messages is not None
messages: list[dict[str, Any]] = []
streams: list[str | None] = []
target_indices: list[int] = []
for turn in recipe.messages:
if turn.if_present is not None and bindings.get(turn.if_present) is None:
continue
message = {"role": turn.role}
if turn.content is not None:
message["content"] = _render_content(turn.content, bindings)
if turn.tool_calls_from is not None:
row = bindings.get(turn.tool_calls_from)
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
if tool_calls:
message["tool_calls"] = copy.deepcopy(tool_calls)
message_idx = len(messages)
messages.append(message)
streams.append(turn.stream)
if turn.target:
target_indices.append(message_idx)
if not target_indices:
return None
rendered = {
"messages": messages,
"message_streams": streams,
"target_message_indices": target_indices,
}
_validate_rendered(rendered)
return rendered
def _render_content(
content: str | list[dict[str, Any]],
bindings: dict[str, LanguageRow | str | None],
) -> str | list[dict[str, Any]]:
"""Substitute bindings into a string or each string field of multimodal blocks."""
if isinstance(content, str):
return _substitute(content, bindings)
rendered_blocks = []
for block in content:
rendered_block = copy.deepcopy(block)
for key, value in rendered_block.items():
if isinstance(value, str):
rendered_block[key] = _substitute(value, bindings)
rendered_blocks.append(rendered_block)
return rendered_blocks
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
def replace(match: re.Match[str]) -> str:
"""Resolve a single ``${name}`` match to its bound string value."""
name = match.group(1)
if name not in bindings:
raise ValueError(f"Unknown template binding: {name!r}")
value = bindings[name]
if value is None:
return ""
if isinstance(value, dict):
content = value.get("content")
return "" if content is None else str(content)
return str(value)
return _PLACEHOLDER_RE.sub(replace, template)
def _validate_rendered(rendered: RenderedMessages) -> None:
"""Sanity-check the rendered output for stream/target alignment."""
messages = rendered["messages"]
streams = rendered["message_streams"]
target_indices = rendered["target_message_indices"]
if len(streams) != len(messages):
raise ValueError("message_streams must be aligned with messages.")
if not target_indices:
raise ValueError("Rendered samples must contain at least one target message.")
for idx in target_indices:
if idx < 0 or idx >= len(messages):
raise ValueError(f"Target message index {idx} is out of bounds.")
for idx, stream in enumerate(streams):
if stream is None:
raise ValueError(f"Rendered message {idx} has no stream.")
def _nth_relative(
name: str,
t: float,
persistent: Sequence[LanguageRow],
style: str | None,
offset: int,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> LanguageRow | None:
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
_validate_persistent_resolver(name, style)
if abs(offset) < 1:
raise ValueError(f"{name} offset must be non-zero.")
rows = sorted(
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
key=_row_sort_key,
)
if not rows:
return None
anchor_idx = None
for idx, row in enumerate(rows):
if _timestamp(row) <= t:
anchor_idx = idx
else:
break
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
return None
return rows[target_idx]
def _validate_persistent_resolver(name: str, style: str | None) -> None:
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
if style is None:
raise ValueError(f"{name} requires a persistent style.")
if column_for_style(style) != LANGUAGE_PERSISTENT:
raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
def _matching_rows(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> list[LanguageRow]:
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
return [
row
for row in rows
if (style is None or row.get("style") == style)
and (role is None or row.get("role") == role)
and (tool_name is None or _row_has_tool_name(row, tool_name))
and (camera is None or row.get("camera") == camera)
]
def _select_one(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> LanguageRow | None:
"""Return the single matching row, or raise if the resolver is ambiguous.
Multiple matches always raise — even when the caller already passed
some selectors — because remaining ambiguity means the data has
several rows that look identical to the resolver and the caller
needs to pin down a specific one (e.g. add ``camera=...`` for VQA
rows shared across cameras).
"""
if not rows:
return None
if len(rows) > 1:
raise ValueError(
f"Ambiguous resolver for style={style!r} role={role!r} "
f"tool_name={tool_name!r} camera={camera!r}: {len(rows)} matching rows. "
f"Add a selector that distinguishes them."
)
return rows[0]
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
"""Stable sort key for both persistent and event rows.
Event rows lack ``timestamp`` (it is implicit in the frame), so default
to ``0.0`` — within a single frame all event rows share the same sort
bucket and are tiebroken by ``(style, role)``.
"""
timestamp = row.get("timestamp")
ts = (
float(timestamp.item() if hasattr(timestamp, "item") else timestamp) if timestamp is not None else 0.0
)
return (ts, row.get("style") or "", row.get("role") or "")
def _timestamp(row: LanguageRow) -> float:
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
value = row["timestamp"]
return float(value.item() if hasattr(value, "item") else value)
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
for tool_call in row.get("tool_calls") or []:
if isinstance(tool_call, str):
continue
function = tool_call.get("function") if isinstance(tool_call, dict) else None
if isinstance(function, dict) and function.get("name") == tool_name:
return True
return False
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
normalized = []
for row in rows:
if row is None:
continue
if hasattr(row, "as_py"):
row = row.as_py()
if not isinstance(row, dict):
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
normalized.append(dict(row))
return normalized

View File

@@ -88,7 +88,6 @@ VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
@@ -130,6 +129,9 @@ class DatasetInfo:
# Optional metadata
robot_type: str | None = None
splits: dict[str, str] = field(default_factory=dict)
# OpenAI-style tool schemas declared by the dataset. ``None`` means the
# dataset doesn't declare any — readers fall back to ``DEFAULT_TOOLS``.
tools: list[dict] | None = None
def __post_init__(self) -> None:
# Coerce feature shapes from list to tuple — JSON deserialisation
@@ -151,11 +153,15 @@ class DatasetInfo:
"""Return a JSON-serialisable dict.
Converts tuple shapes back to lists so ``json.dump`` can handle them.
Drops ``tools`` when unset so existing datasets keep a clean
``info.json``.
"""
d = dataclasses.asdict(self)
for ft in d["features"].values():
if isinstance(ft.get("shape"), tuple):
ft["shape"] = list(ft["shape"])
if d.get("tools") is None:
d.pop("tools", None)
return d
@classmethod

View File

@@ -16,8 +16,6 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .eo1.configuration_eo1 import EO1Config as EO1Config
from .evo1.configuration_evo1 import Evo1Config as Evo1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
@@ -41,10 +39,8 @@ __all__ = [
# Configuration classes
"ACTConfig",
"DiffusionConfig",
"Evo1Config",
"GrootConfig",
"MultiTaskDiTConfig",
"EO1Config",
"PI0Config",
"PI0FastConfig",
"PI05Config",

View File

@@ -100,8 +100,8 @@ class DiffusionConfig(PreTrainedConfig):
# Inputs / output structure.
n_obs_steps: int = 2
horizon: int = 64
n_action_steps: int = 32
horizon: int = 16
n_action_steps: int = 8
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
@@ -122,10 +122,10 @@ class DiffusionConfig(PreTrainedConfig):
crop_ratio: float = 1.0
crop_shape: tuple[int, int] | None = None
crop_is_random: bool = True
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
use_group_norm: bool = False
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
spatial_softmax_num_keypoints: int = 32
use_separate_rgb_encoder_per_camera: bool = True
use_separate_rgb_encoder_per_camera: bool = False
# Unet.
down_dims: tuple[int, ...] = (512, 1024, 2048)
kernel_size: int = 5

View File

@@ -1 +0,0 @@
../../../../docs/source/eo1.mdx

View File

@@ -1,7 +0,0 @@
#!/usr/bin/env python
from .configuration_eo1 import EO1Config
from .modeling_eo1 import EO1Policy
from .processor_eo1 import make_eo1_pre_post_processors
__all__ = ["EO1Config", "EO1Policy", "make_eo1_pre_post_processors"]

View File

@@ -1,193 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig,
Qwen2_5_VLTextConfig,
Qwen2_5_VLVisionConfig,
)
else:
Qwen2_5_VLConfig = None
Qwen2_5_VLTextConfig = None
Qwen2_5_VLVisionConfig = None
@PreTrainedConfig.register_subclass("eo1")
@dataclass
class EO1Config(PreTrainedConfig):
"""Configuration for native EO1 policy integration in LeRobot."""
vlm_base: str = "Qwen/Qwen2.5-VL-3B-Instruct"
vlm_config: dict | None = None
# Vision processor settings.
image_min_pixels: int | None = 64 * 28 * 28
image_max_pixels: int | None = 128 * 28 * 28
use_fast_processor: bool = False
# Execution and action horizon.
n_obs_steps: int = 1
chunk_size: int = 8
n_action_steps: int = 8
# State/action padding to match EO1 flow head dimensionality.
max_state_dim: int = 32
max_action_dim: int = 32
# Flow matching sampling.
num_denoise_steps: int = 10
num_action_layers: int = 2
action_act: str = "linear"
time_sampling_beta_alpha: float = 1.5
time_sampling_beta_beta: float = 1.0
time_sampling_scale: float = 0.999
time_sampling_offset: float = 0.001
min_period: float = 4e-3
max_period: float = 4.0
supervise_padding_action_dims: bool = True
supervise_padding_actions: bool = True
# Policy-level dtype request for the Qwen backbone.
# - "auto": follow the backbone config/checkpoint default dtype. For Qwen2.5-VL this resolves to bf16.
# The EO1 flow-matching head still keeps its own parameters in fp32.
# - "bfloat16": force the backbone to initialize/load in bf16 regardless of the saved config default.
# - "float32": force the backbone to initialize/load in fp32 for maximum numerical conservatism.
dtype: str = "auto" # Options: "auto", "bfloat16", "float32"
force_fp32_autocast: bool = True
# Optional attention backend request passed through to the Qwen backbone.
# Common values: None, "eager", "sdpa", "flash_attention_2".
attn_implementation: str | None = None
# Training settings.
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Optimizer settings aligned with EO1/experiments/2_libero/train.sh and EO1 TrainPipelineConfig defaults.
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.1
optimizer_grad_clip_norm: float = 1.0
# Scheduler settings aligned with EO1 train.sh: cosine schedule with warmup_ratio=0.03.
# Note: These will auto-scale if --steps < scheduler_decay_steps
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
scheduler_warmup_steps: int = 900 # 0.03 * 30_000 long-run steps
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 0.0
def __post_init__(self):
super().__post_init__()
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
# Populate the serialized backbone config only when the caller did not provide one.
if self.vlm_config is None:
require_package("transformers", extra="eo1")
self.vlm_config = Qwen2_5_VLConfig.from_pretrained(self.vlm_base).to_dict()
@property
def vlm_backbone_config(self) -> Qwen2_5_VLConfig:
require_package("transformers", extra="eo1")
config_dict = deepcopy(self.vlm_config)
if self.attn_implementation is not None:
config_dict["attn_implementation"] = self.attn_implementation
return Qwen2_5_VLConfig(**config_dict)
@property
def text_config(self) -> Qwen2_5_VLTextConfig:
return self.vlm_backbone_config.text_config
@property
def vision_config(self) -> Qwen2_5_VLVisionConfig:
return self.vlm_backbone_config.vision_config
def validate_features(self) -> None:
"""Validate and set up EO1 input and output features."""
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"EO1 policy requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
if OBS_STATE not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,),
)
self.input_features[OBS_STATE] = state_feature
if ACTION not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,),
)
self.output_features[ACTION] = action_feature
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -1,620 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import logging
import math
from collections import deque
from typing import TYPE_CHECKING, Any
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torch.utils.checkpoint
from torch import Tensor
from lerobot.policies.eo1.configuration_eo1 import EO1Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from transformers.utils import torch_compilable_check
else:
ACT2FN = None
Qwen2_5_VLForConditionalGeneration = None
torch_compilable_check = None
logger = logging.getLogger(__name__)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
"""
if vector.shape[-1] >= new_dim:
return vector
return F.pad(vector, (0, new_dim - vector.shape[-1]))
class EO1Policy(PreTrainedPolicy):
"""EO1 policy wrapper for LeRobot robot-only training/evaluation."""
config_class = EO1Config
name = "eo1"
def __init__(self, config: EO1Config, **kwargs):
require_package("transformers", extra="eo1")
super().__init__(config)
config.validate_features()
self.config = config
if config.pretrained_path is None:
# Initialize from pretrained VLM
vlm_backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
config.vlm_base,
dtype=config.dtype,
attn_implementation=config.attn_implementation,
)
else:
vlm_backbone = Qwen2_5_VLForConditionalGeneration._from_config(
config.vlm_backbone_config,
dtype=config.vlm_backbone_config.dtype if config.dtype == "auto" else config.dtype,
)
self.model = EO1VisionFlowMatchingModel(config, vlm_backbone)
if config.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
self.model.to(config.device)
self.reset()
def reset(self):
self._action_queue = deque(maxlen=self.config.n_action_steps)
@staticmethod
def _get_model_inputs(batch: dict[str, Tensor], excluded_keys: set[str]) -> dict[str, Tensor]:
return {key: value for key, value in batch.items() if key not in excluded_keys}
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
state = self.prepare_state(batch[OBS_STATE])
actions = self.prepare_action(batch[ACTION])
model_inputs = self._get_model_inputs(batch, {OBS_STATE, ACTION})
loss = self.model(states=state, action=actions, **model_inputs)
loss_dict = {"loss": loss.item()}
return loss, loss_dict
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
self.eval()
states = self.prepare_state(batch[OBS_STATE])
model_inputs = self._get_model_inputs(batch, {OBS_STATE})
actions = self.model.sample_actions(states=states, **model_inputs).to(torch.float32)
original_action_dim = self.config.output_features[ACTION].shape[0]
return actions[:, :, :original_action_dim]
def prepare_state(self, state: Tensor) -> Tensor:
return pad_vector(state, self.config.max_state_dim)
def prepare_action(self, action: Tensor) -> Tensor:
return pad_vector(action, self.config.max_action_dim)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
self.eval()
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def get_optim_params(self) -> dict:
return self.parameters()
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
return torch.float32
if device_type == "cpu":
# CPU doesn't support bfloat16, use float32 instead
if target_dtype == torch.bfloat16:
return torch.float32
if target_dtype == torch.float64:
return torch.float64
return target_dtype
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
alpha_t = torch.tensor(alpha, dtype=torch.float32)
beta_t = torch.tensor(beta, dtype=torch.float32)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,)).to(device)
class EO1VisionActionProjector(torch.nn.Sequential):
"""This block implements the multi-layer perceptron (MLP) module."""
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 2,
activation_layer: str = "linear",
bias: bool = True,
device: Any = None,
dtype: torch.dtype = torch.float32,
):
layers = []
in_dim = in_channels
hidden_channels = [in_dim] * (num_layers - 1) + [out_channels]
for hidden_dim in hidden_channels[:-1]:
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device))
layers.append(ACT2FN[activation_layer])
in_dim = hidden_dim
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias, dtype=dtype, device=device))
super().__init__(*layers)
@property
def dtype(self):
return self[0].weight.dtype
class EO1VisionFlowMatchingModel(nn.Module):
def __init__(
self,
config: EO1Config,
vlm_backbone: Qwen2_5_VLForConditionalGeneration | None = None,
):
require_package("transformers", extra="eo1")
super().__init__()
self.config = config
# Preserve the backbone dtype selected at construction time so Qwen's fp32 rotary buffers stay intact.
self.vlm_backbone = vlm_backbone
self.hidden_size = self.vlm_backbone.config.text_config.hidden_size
max_state_dim = config.max_state_dim
max_action_dim = config.max_action_dim
self.state_proj = nn.Linear(max_state_dim, self.hidden_size, dtype=torch.float32)
self.action_in_proj = nn.Linear(max_action_dim, self.hidden_size, dtype=torch.float32)
self.action_out_proj = EO1VisionActionProjector(
self.hidden_size,
max_action_dim,
config.num_action_layers,
config.action_act,
dtype=torch.float32,
)
self.action_time_mlp_in = nn.Linear(self.hidden_size * 2, self.hidden_size, dtype=torch.float32)
self.action_time_mlp_out = nn.Linear(self.hidden_size, self.hidden_size, dtype=torch.float32)
self.gradient_checkpointing_enabled = False
def get_input_embeddings(self):
return self.vlm_backbone.get_input_embeddings()
def flow_head_autocast_context(self):
if self.config.force_fp32_autocast:
return torch.autocast(
device_type=self.state_proj.weight.device.type,
enabled=False,
)
return contextlib.nullcontext()
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for the Qwen2.5-VL backbone."""
self.gradient_checkpointing_enabled = True
self.vlm_backbone.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
logger.info("Enabled gradient checkpointing for EO1VisionFlowMatchingModel")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing for the Qwen2.5-VL backbone."""
self.gradient_checkpointing_enabled = False
self.vlm_backbone.gradient_checkpointing_disable()
logger.info("Disabled gradient checkpointing for EO1VisionFlowMatchingModel")
def _apply_checkpoint(self, func, *args, **kwargs):
"""Apply manual gradient checkpointing to EO1 flow-head computations when training."""
if self.gradient_checkpointing_enabled and self.training and torch.is_grad_enabled():
return torch.utils.checkpoint.checkpoint(
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
)
return func(*args, **kwargs)
def sample_noise(self, shape, device):
noise = torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
return noise
def sample_time(self, bsize, device):
time_beta = sample_beta(
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
)
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
return time.to(dtype=torch.float32, device=device)
def get_placeholder_mask(
self,
input_ids: torch.LongTensor | None,
inputs_embeds: torch.FloatTensor | None,
state_features: torch.FloatTensor | None = None,
action_features: torch.FloatTensor | None = None,
*,
state_token_id: int,
action_token_id: int,
) -> tuple[torch.BoolTensor, torch.BoolTensor]:
"""Return EO1 state/action placeholder masks, following Qwen's multimodal mask style."""
if input_ids is None:
special_state_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(state_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_state_mask = special_state_mask.all(-1)
special_action_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(action_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_action_mask = special_action_mask.all(-1)
else:
special_state_mask = input_ids == state_token_id
special_action_mask = input_ids == action_token_id
n_state_tokens = special_state_mask.sum()
special_state_mask = (
special_state_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
)
if state_features is not None:
torch_compilable_check(
inputs_embeds[special_state_mask].numel() == state_features.numel(),
f"State features and state tokens do not match, tokens: {n_state_tokens}, features: {state_features.shape[0]}",
)
n_action_tokens = special_action_mask.sum()
special_action_mask = (
special_action_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
)
if action_features is not None:
torch_compilable_check(
inputs_embeds[special_action_mask].numel() == action_features.numel(),
f"Action features and action tokens do not match, tokens: {n_action_tokens}, features: {action_features.shape[0]}",
)
return special_state_mask, special_action_mask
def embed_prefix(
self,
input_ids: torch.LongTensor,
states: torch.Tensor,
*,
state_token_id: int,
action_token_id: int,
) -> torch.FloatTensor:
"""Embed the EO1 prefix tokens before native Qwen injects multimodal features."""
# Get the input embeddings for the input IDs
def input_embed_func(input_ids: torch.LongTensor) -> torch.FloatTensor:
return self.get_input_embeddings()(input_ids)
inputs_embeds = self._apply_checkpoint(input_embed_func, input_ids)
# Project the states to the hidden size
def state_proj_func(states: torch.Tensor) -> torch.FloatTensor:
with self.flow_head_autocast_context():
states = states.to(dtype=self.state_proj.weight.dtype)
return self.state_proj(states)
state_embs = self._apply_checkpoint(state_proj_func, states)
state_mask, _ = self.get_placeholder_mask(
input_ids,
inputs_embeds,
state_features=state_embs,
state_token_id=state_token_id,
action_token_id=action_token_id,
)
state_embs = state_embs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(state_mask, state_embs)
return inputs_embeds
def embed_suffix(
self,
timestep: torch.Tensor,
noisy_actions: torch.Tensor,
) -> torch.FloatTensor:
"""Embed the suffix"""
def action_proj_func(noisy_actions: torch.Tensor) -> torch.FloatTensor:
with self.flow_head_autocast_context():
noisy_actions = noisy_actions.to(dtype=self.action_in_proj.weight.dtype)
return self.action_in_proj(noisy_actions)
action_embs = self._apply_checkpoint(action_proj_func, noisy_actions)
time_embs = create_sinusoidal_pos_embedding(
timestep,
self.hidden_size,
min_period=self.config.min_period,
max_period=self.config.max_period,
device=action_embs.device,
)
time_embs = time_embs.to(dtype=action_embs.dtype)
time_embs = time_embs[:, None, :].expand_as(action_embs)
action_time_embs = torch.cat([action_embs, time_embs], dim=2)
def mlp_func(action_time_embs: torch.Tensor) -> torch.FloatTensor:
with self.flow_head_autocast_context():
action_time_embs = action_time_embs.to(dtype=self.action_time_mlp_in.weight.dtype)
action_time_embs = self.action_time_mlp_in(action_time_embs)
action_time_embs = F.silu(action_time_embs)
return self.action_time_mlp_out(action_time_embs)
action_time_embs = self._apply_checkpoint(mlp_func, action_time_embs)
return action_time_embs
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
image_grid_thw: torch.LongTensor | None = None,
mm_token_type_ids: torch.IntTensor | None = None,
states: torch.FloatTensor | None = None,
action: torch.FloatTensor | None = None,
action_is_pad: torch.BoolTensor | None = None,
*,
state_token_id: int,
action_token_id: int,
**kwargs,
) -> Tensor:
"""Run the EO1 training forward pass and compute the flow-matching loss."""
# 1. Build the EO1 prefix with state placeholders resolved.
inputs_embeds = self.embed_prefix(
input_ids,
states=states,
state_token_id=state_token_id,
action_token_id=action_token_id,
)
# 2. Sample the diffusion target and replace the action placeholders.
time = self.sample_time(action.shape[0], inputs_embeds.device)
noise = self.sample_noise(action.shape, inputs_embeds.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * action
u_t = noise - action
action_time_embs = self.embed_suffix(time, x_t)
_, action_mask = self.get_placeholder_mask(
input_ids,
inputs_embeds,
action_features=action_time_embs,
state_token_id=state_token_id,
action_token_id=action_token_id,
)
action_time_embs = action_time_embs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(action_mask, action_time_embs)
# 3. Optionally drop padded action tokens from backbone attention.
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
if not self.config.supervise_padding_actions:
action_is_pad = action_is_pad.to(device=inputs_embeds.device, dtype=torch.bool)
action_token_mask = action_mask[..., 0]
action_padding_mask = torch.zeros_like(action_token_mask)
action_padding_mask = action_padding_mask.masked_scatter(
action_token_mask,
action_is_pad.reshape(-1),
)
attention_mask = attention_mask.masked_fill(action_padding_mask, 0)
# 4. Run the Qwen backbone on the fused EO1 sequence.
def vlm_forward_func(
input_ids: torch.LongTensor,
attention_mask: torch.Tensor | None,
inputs_embeds: torch.FloatTensor,
pixel_values: torch.Tensor | None,
image_grid_thw: torch.LongTensor | None,
mm_token_type_ids: torch.IntTensor | None,
) -> torch.FloatTensor:
outputs = self.vlm_backbone.model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
mm_token_type_ids=mm_token_type_ids,
use_cache=False,
output_hidden_states=False,
return_dict=True,
)
return outputs.last_hidden_state
hidden_states = self._apply_checkpoint(
vlm_forward_func,
input_ids,
attention_mask,
inputs_embeds,
pixel_values,
image_grid_thw,
mm_token_type_ids,
)
action_hidden_states = hidden_states[action_mask[..., 0]]
# 5. Project the action-token hidden states back to the flow target space.
def action_out_proj_func(action_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
with self.flow_head_autocast_context():
action_hidden_states = action_hidden_states.to(dtype=self.action_out_proj.dtype)
return self.action_out_proj(action_hidden_states)
v_t = self._apply_checkpoint(action_out_proj_func, action_hidden_states)
v_t = v_t.reshape(u_t.shape).to(dtype=u_t.dtype)
losses = F.mse_loss(u_t, v_t, reduction="none")
# 6. Apply the configured supervision mask and reduce the loss.
if not self.config.supervise_padding_action_dims:
original_action_dim = self.config.output_features[ACTION].shape[0]
losses = losses[..., :original_action_dim]
if not self.config.supervise_padding_actions:
losses = losses[~action_is_pad]
return losses.mean()
@torch.no_grad()
def sample_actions(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
pixel_values: torch.Tensor | None = None,
image_grid_thw: torch.LongTensor | None = None,
mm_token_type_ids: torch.IntTensor | None = None,
states: torch.Tensor | None = None,
*,
state_token_id: int,
action_token_id: int,
**kwargs,
) -> Tensor:
"""Sample actions from the model."""
if states is None:
raise ValueError("states are required for EO1 action sampling.")
if mm_token_type_ids is None:
raise ValueError("mm_token_type_ids are required for EO1 action sampling.")
# 1. Resolve the left-padded rollout prompt and locate the action span.
chunk_size = self.config.chunk_size
inputs_embeds = self.embed_prefix(
input_ids,
states=states,
state_token_id=state_token_id,
action_token_id=action_token_id,
).clone()
_, action_placeholder_mask = self.get_placeholder_mask(
input_ids,
inputs_embeds,
state_token_id=state_token_id,
action_token_id=action_token_id,
)
action_mask = action_placeholder_mask[..., 0]
token_counts = action_mask.sum(dim=1)
if not torch.all(token_counts == chunk_size):
raise ValueError(
f"Each sample must contain exactly {chunk_size} action tokens, got {token_counts.tolist()}."
)
if action_mask.ne(action_mask[:1]).any():
raise ValueError(
"Batch inference expects all samples to share the same action token mask after left padding."
)
act_start = int(action_mask[0].to(torch.int64).argmax().item())
act_end = act_start + self.config.chunk_size
if not torch.all(action_mask[:, act_start:act_end]):
raise ValueError("Action tokens must form a contiguous chunk of length chunk_size.")
act_slice = slice(act_start, act_end)
# 2. Encode the fixed prefix once and cache its KV state.
batch_size = input_ids.shape[0]
device = inputs_embeds.device
attention_mask = attention_mask.to(device)
mm_token_type_ids = mm_token_type_ids.to(device)
position_ids, _ = self.vlm_backbone.model.get_rope_index(
input_ids,
image_grid_thw=image_grid_thw,
attention_mask=attention_mask,
mm_token_type_ids=mm_token_type_ids,
)
position_ids = position_ids.to(device)
outputs = self.vlm_backbone.model(
input_ids=input_ids[:, :act_start],
attention_mask=attention_mask[:, :act_start],
position_ids=position_ids[..., :act_start],
inputs_embeds=inputs_embeds[:, :act_start],
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
mm_token_type_ids=mm_token_type_ids[:, :act_start],
use_cache=True,
return_dict=True,
)
x_t = self.sample_noise(
(batch_size, chunk_size, self.config.max_action_dim),
device,
).to(dtype=self.action_in_proj.weight.dtype)
dt = -1.0 / self.config.num_denoise_steps
past_key_values = outputs.past_key_values
# 3. Denoise only the action chunk while keeping the prefix cache invariant.
for step in range(self.config.num_denoise_steps):
time = torch.full(
(batch_size,),
1.0 + step * dt,
device=device,
dtype=torch.float32,
)
action_time_embs = self.embed_suffix(time, x_t)
inputs_embeds[:, act_slice] = action_time_embs.to(inputs_embeds.dtype)
# Keep the prefix KV cache invariant across denoising steps.
past_key_values.crop(act_start)
outputs = self.vlm_backbone.model(
attention_mask=attention_mask[:, :act_end],
past_key_values=past_key_values,
inputs_embeds=inputs_embeds[:, act_slice],
position_ids=position_ids[..., act_slice],
use_cache=True,
return_dict=True,
)
with self.flow_head_autocast_context():
hidden_states = outputs.last_hidden_state[:, :chunk_size]
hidden_states = hidden_states.to(dtype=self.action_out_proj.dtype)
v_t = self.action_out_proj(hidden_states)
x_t += dt * v_t.reshape(x_t.shape)
return x_t

View File

@@ -1,282 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.policies.eo1.configuration_eo1 import EO1Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.types import TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
else:
Qwen2_5_VLProcessor = None
SYSTEM_MESSAGE = "You are a helpful physical assistant."
# EO-1 special tokens
ACTION_START_TOKEN = "<|action_start|>" # nosec B105
DEFAULT_ACTION_TOKEN = "<|action_pad|>" # nosec B105
ACTION_END_TOKEN = "<|action_end|>" # nosec B105
STATE_START_TOKEN = "<|state_start|>" # nosec B105
DEFAULT_STATE_TOKEN = "<|state_pad|>" # nosec B105
STATE_END_TOKEN = "<|state_end|>" # nosec B105
TASK_VLA_TOKEN = "<|vla|>" # nosec B105
EO1_SPECIAL_TOKENS = [
ACTION_START_TOKEN,
DEFAULT_ACTION_TOKEN,
ACTION_END_TOKEN,
STATE_START_TOKEN,
DEFAULT_STATE_TOKEN,
STATE_END_TOKEN,
TASK_VLA_TOKEN,
]
@dataclass
@ProcessorStepRegistry.register(name="eo1_conversation_template_processor")
class EO1ConversationTemplateStep(ComplementaryDataProcessorStep):
input_features: dict[str, PolicyFeature] | dict[str, dict[str, Any]]
chunk_size: int
_image_keys: list[str] = field(default_factory=list, init=False, repr=False)
def __post_init__(self):
# Robust JSON deserialization handling (guard empty maps).
if self.input_features:
first_val = next(iter(self.input_features.values()))
if isinstance(first_val, dict):
reconstructed = {}
for key, ft_dict in self.input_features.items():
reconstructed[key] = PolicyFeature(
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
)
self.input_features = reconstructed
self._image_keys = [
key for key, value in self.input_features.items() if value.type == FeatureType.VISUAL
]
def complementary_data(self, complementary_data):
tasks = complementary_data.get("task")
if tasks is None:
raise ValueError("Task is required for EO1ConversationTemplateStep.")
observation = self.transition.get(TransitionKey.OBSERVATION)
if observation is None:
raise ValueError("Observation is required for EO1ConversationTemplateStep.")
if OBS_STATE in observation and observation[OBS_STATE].shape[0] != len(tasks):
raise ValueError("Batch size mismatch between observation.state and task list.")
# LeRobot visual observations reach in processor as float32 tensors in [0, 1].
# Convert to uint8 in [0, 255] to meet the input requirement of Qwen2.5-VL-3B-Instruct.
images = {
key: observation[key].clamp(0, 1).mul(255.0).round().to(torch.uint8) for key in self._image_keys
}
messages = []
for i in range(len(tasks)):
content = [
*[{"type": "image", "image": images[key][i]} for key in self._image_keys],
{
"type": "text",
"text": (
f"{STATE_START_TOKEN}{DEFAULT_STATE_TOKEN}{STATE_END_TOKEN}{tasks[i]}{TASK_VLA_TOKEN}"
),
},
]
messages.append(
[
{"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
{"role": "user", "content": content},
{
"role": "assistant",
"content": [
{
"type": "text",
"text": f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN * self.chunk_size}{ACTION_END_TOKEN}",
}
],
},
]
)
complementary_data["messages"] = messages
return complementary_data
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
This step only materializes EO1-specific message objects in complementary_data.
PipelineFeatureType tracks only ACTION and OBSERVATION, so there is no static
feature contract change to record here.
"""
return features
def get_config(self) -> dict[str, Any]:
return {
"input_features": {
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.input_features.items()
},
"chunk_size": self.chunk_size,
}
@dataclass
@ProcessorStepRegistry.register(name="eo1_qwen_processor")
class EO1QwenProcessorStep(ComplementaryDataProcessorStep):
processor_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
image_min_pixels: int | None = 64 * 28 * 28
image_max_pixels: int | None = 128 * 28 * 28
use_fast_processor: bool = False
_processor: Qwen2_5_VLProcessor | None = field(default=None, init=False, repr=False)
_state_token_id: int | None = field(default=None, init=False, repr=False)
_action_token_id: int | None = field(default=None, init=False, repr=False)
def __post_init__(self):
require_package("transformers", extra="eo1")
self._processor = Qwen2_5_VLProcessor.from_pretrained(
self.processor_name,
use_fast=self.use_fast_processor,
)
self._processor.tokenizer.add_tokens(EO1_SPECIAL_TOKENS, special_tokens=True)
self._state_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_STATE_TOKEN)
self._action_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_ACTION_TOKEN)
def complementary_data(self, complementary_data):
messages = complementary_data.pop("messages", None)
if messages is None:
raise ValueError("Messages are required for EO1QwenProcessorStep.")
# Rollout batches use left padding so action spans stay aligned across samples.
# Supervised batches use right padding to match standard training collation.
padding_side = "right" if self.transition.get(TransitionKey.ACTION) is not None else "left"
inputs = self._processor.apply_chat_template(
messages,
tokenize=True,
padding=True,
padding_side=padding_side,
min_pixels=self.image_min_pixels,
max_pixels=self.image_max_pixels,
add_generation_prompt=False,
return_dict=True,
return_tensors="pt",
)
complementary_data["input_ids"] = inputs["input_ids"]
complementary_data["pixel_values"] = inputs["pixel_values"]
complementary_data["image_grid_thw"] = inputs["image_grid_thw"]
complementary_data["attention_mask"] = inputs["attention_mask"]
complementary_data["mm_token_type_ids"] = inputs["mm_token_type_ids"]
complementary_data["state_token_id"] = self._state_token_id
complementary_data["action_token_id"] = self._action_token_id
return complementary_data
def get_config(self) -> dict[str, Any]:
return {
"processor_name": self.processor_name,
"image_min_pixels": self.image_min_pixels,
"image_max_pixels": self.image_max_pixels,
"use_fast_processor": self.use_fast_processor,
}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
This step only converts the messages to the model input format.
"""
return features
def make_eo1_pre_post_processors(
config: EO1Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Build pre/post processor pipelines for EO1."""
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
EO1ConversationTemplateStep(input_features=config.input_features, chunk_size=config.chunk_size),
EO1QwenProcessorStep(
processor_name=config.vlm_base,
image_min_pixels=config.image_min_pixels,
image_max_pixels=config.image_max_pixels,
use_fast_processor=config.use_fast_processor,
),
DeviceProcessorStep(device=config.device),
]
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)

View File

@@ -1 +0,0 @@
../../../../docs/source/evo1.mdx

View File

@@ -1,19 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_evo1 import Evo1Config
from .modeling_evo1 import EVO1Policy
from .processor_evo1 import make_evo1_pre_post_processors
__all__ = ["Evo1Config", "EVO1Policy", "make_evo1_pre_post_processors"]

View File

@@ -1,211 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from dataclasses import dataclass, field
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
@LRSchedulerConfig.register_subclass("evo1_exact")
@dataclass
class Evo1SchedulerConfig(LRSchedulerConfig):
num_warmup_steps: int
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
def lr_lambda(current_step: int) -> float:
if current_step < self.num_warmup_steps:
return current_step / max(1, self.num_warmup_steps)
progress = (current_step - self.num_warmup_steps) / max(
1, num_training_steps - self.num_warmup_steps
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
return LambdaLR(optimizer, lr_lambda, -1)
@PreTrainedConfig.register_subclass("evo1")
@dataclass
class Evo1Config(PreTrainedConfig):
training_stage: str = "stage1"
use_amp: bool = True
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
max_state_dim: int = 24
max_action_dim: int = 24
max_views: int = 3
image_resolution: tuple[int, int] = (448, 448)
empty_cameras: int = 0
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
vlm_model_name: str = "OpenGVLab/InternVL3-1B"
vlm_num_layers: int | None = 14
vlm_dtype: str = "bfloat16"
use_flash_attn: bool = True
action_head: str = "flowmatching"
embed_dim: int = 896
hidden_dim: int = 1024
state_hidden_dim: int = 1024
num_heads: int = 8
num_layers: int = 8
dropout: float = 0.0
num_inference_timesteps: int = 32
num_categories: int = 1
return_cls_only: bool = False
enable_gradient_checkpointing: bool = True
gradient_checkpointing_use_reentrant: bool = False
finetune_vlm: bool | None = None
finetune_language_model: bool | None = None
finetune_vision_model: bool | None = None
finetune_action_head: bool | None = None
task_field: str = "task"
embodiment_id_field: str | None = None
default_embodiment_id: int = 0
optimizer_lr: float = 1e-5
optimizer_betas: tuple[float, float] = (0.9, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 300
drop_last: bool = True
def __post_init__(self):
super().__post_init__()
if self.training_stage not in {"stage1", "stage2"}:
raise ValueError(
f"Unsupported EVO1 training_stage '{self.training_stage}', expected 'stage1' or 'stage2'"
)
if self.training_stage == "stage1":
if self.finetune_vlm is None:
self.finetune_vlm = False
if self.finetune_language_model is None:
self.finetune_language_model = False
if self.finetune_vision_model is None:
self.finetune_vision_model = False
if self.finetune_action_head is None:
self.finetune_action_head = True
elif self.training_stage == "stage2":
has_explicit_branch_flags = any(
flag is not None for flag in (self.finetune_language_model, self.finetune_vision_model)
)
if not has_explicit_branch_flags:
if self.finetune_vlm is None:
self.finetune_vlm = True
if self.finetune_language_model is None:
self.finetune_language_model = True
if self.finetune_vision_model is None:
self.finetune_vision_model = True
elif self.finetune_vlm is None:
self.finetune_vlm = bool(self.finetune_language_model or self.finetune_vision_model)
if self.finetune_action_head is None:
self.finetune_action_head = True
if self.finetune_vlm is None:
self.finetune_vlm = False
if self.finetune_language_model is None:
self.finetune_language_model = False
if self.finetune_vision_model is None:
self.finetune_vision_model = False
if self.finetune_action_head is None:
self.finetune_action_head = False
branch_vlm = self.finetune_language_model or self.finetune_vision_model
if self.finetune_vlm != branch_vlm:
raise ValueError(
"Inconsistent EVO1 finetune config: "
f"finetune_vlm={self.finetune_vlm} but "
f"(finetune_language_model or finetune_vision_model)={branch_vlm}. "
"When branch-level flags are used, finetune_vlm must match their effective union."
)
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) must be <= chunk_size ({self.chunk_size})"
)
def validate_features(self) -> None:
if self.input_features is None:
self.input_features = {}
if self.output_features is None:
self.output_features = {}
for i in range(self.empty_cameras):
key = OBS_IMAGES + f".empty_camera_{i}"
if key not in self.input_features:
self.input_features[key] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, *self.image_resolution),
)
if OBS_STATE not in self.input_features:
self.input_features[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,),
)
if ACTION not in self.output_features:
self.output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,),
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return Evo1SchedulerConfig(
num_warmup_steps=self.scheduler_warmup_steps,
)
@property
def observation_delta_indices(self) -> list[int]:
return [0]
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -1,234 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
import torch
import torch.nn as nn
from PIL import Image
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
from lerobot.policies.evo1.internvl3_embedder import InternVL3Embedder
def _cfgget(config: Any, key: str, default=None):
if isinstance(config, dict):
return config.get(key, default)
return getattr(config, key, default)
class EVO1(nn.Module):
def __init__(self, config: dict):
super().__init__()
self.config = config
self._device = _cfgget(config, "device", "cuda")
self.return_cls_only = _cfgget(config, "return_cls_only", False)
vlm_name = _cfgget(config, "vlm_name", "OpenGVLab/InternVL3-1B")
image_size = _cfgget(config, "image_size", 448)
if image_size is None:
image_resolution = _cfgget(config, "image_resolution", (448, 448))
image_size = int(image_resolution[0])
self.embedder = InternVL3Embedder(
model_name=vlm_name,
image_size=image_size,
device=self._device,
num_language_layers=_cfgget(config, "vlm_num_layers", 14),
model_dtype=_cfgget(config, "vlm_dtype", "bfloat16"),
use_flash_attn=_cfgget(config, "use_flash_attn", True),
enable_gradient_checkpointing=_cfgget(config, "enable_gradient_checkpointing", True),
gradient_checkpointing_use_reentrant=_cfgget(
config, "gradient_checkpointing_use_reentrant", False
),
)
action_head_type = _cfgget(config, "action_head", "flowmatching").lower()
if action_head_type != "flowmatching":
raise NotImplementedError(f"Unknown action_head: {action_head_type}")
horizon = _cfgget(config, "action_horizon", _cfgget(config, "horizon", 16))
per_action_dim = _cfgget(config, "per_action_dim", 7)
action_dim = horizon * per_action_dim
if isinstance(config, dict):
config["horizon"] = horizon
config["per_action_dim"] = per_action_dim
config["action_dim"] = action_dim
self.horizon = horizon
self.per_action_dim = per_action_dim
self.action_head = FlowmatchingActionHead(config=config).to(self._device)
def _normalize_image_batches(
self,
images: Sequence[Image.Image | torch.Tensor] | Sequence[Sequence[Image.Image | torch.Tensor]],
prompt: str | list[str] | None,
image_mask: torch.Tensor,
) -> tuple[list[list[Image.Image | torch.Tensor]], list[str], torch.Tensor]:
if not images:
raise ValueError("EVO1 expects at least one image per sample.")
first = images[0]
if isinstance(first, (Image.Image, torch.Tensor)):
image_batches = [list(images)] # type: ignore[arg-type]
else:
image_batches = [list(sample) for sample in images] # type: ignore[arg-type]
batch_size = len(image_batches)
if prompt is None:
prompts = [""] * batch_size
elif isinstance(prompt, str):
prompts = [prompt] * batch_size
else:
prompts = [str(p) for p in prompt]
if len(prompts) != batch_size:
raise ValueError(
f"Prompt batch size {len(prompts)} does not match image batch size {batch_size}"
)
if image_mask.dim() == 1:
image_mask = image_mask.unsqueeze(0)
if image_mask.shape[0] != batch_size:
raise ValueError(
f"image_mask batch size {image_mask.shape[0]} does not match image batch size {batch_size}"
)
return image_batches, prompts, image_mask
def get_vl_embeddings(
self,
images: list[Image.Image | torch.Tensor] | list[list[Image.Image | torch.Tensor]],
image_mask: torch.Tensor,
prompt: str | list[str] | None = None,
return_cls_only: bool | None = None,
) -> torch.Tensor:
if return_cls_only is None:
return_cls_only = self.return_cls_only
image_batches, prompts, image_mask = self._normalize_image_batches(images, prompt, image_mask)
return self.embedder.get_fused_image_text_embedding_from_tensor_images(
image_tensors_batch=image_batches,
image_masks=image_mask,
text_prompts=prompts,
return_cls_only=return_cls_only,
)
def prepare_state(self, state_input: list | torch.Tensor) -> torch.Tensor:
if isinstance(state_input, list):
state_tensor = torch.tensor(state_input)
elif isinstance(state_input, torch.Tensor):
state_tensor = state_input
else:
raise TypeError(f"Unsupported state input type: {type(state_input)}")
if state_tensor.ndim == 1:
state_tensor = state_tensor.unsqueeze(0)
return state_tensor.to(self._device)
def predict_action(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor,
actions_gt: torch.Tensor | None = None,
action_mask: torch.Tensor | None = None,
embodiment_ids: torch.Tensor | None = None,
):
if actions_gt is None:
return self.action_head.get_action(
fused_tokens,
state=state,
action_mask=action_mask,
embodiment_id=embodiment_ids,
)
return self.action_head(
fused_tokens,
state=state,
actions_gt=actions_gt,
action_mask=action_mask,
embodiment_id=embodiment_ids,
)
@torch.no_grad()
def run_inference(
self,
images: list[Image.Image | torch.Tensor],
image_mask: torch.Tensor,
prompt: str,
state_input: list | torch.Tensor,
return_cls_only: bool | None = None,
action_mask: torch.Tensor | None = None,
embodiment_ids: torch.Tensor | None = None,
) -> torch.Tensor:
if image_mask.dim() == 1:
image_mask = image_mask.unsqueeze(0)
fused_tokens = self.get_vl_embeddings(
images=[images],
image_mask=image_mask,
prompt=[prompt],
return_cls_only=return_cls_only,
)
state_tensor = self.prepare_state(state_input)
action = self.predict_action(
fused_tokens,
state_tensor,
action_mask=action_mask,
embodiment_ids=embodiment_ids,
)
if isinstance(action, torch.Tensor) and action.dtype == torch.bfloat16:
action = action.to(torch.float32)
return action
def forward(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor | None = None,
actions_gt: torch.Tensor | None = None,
action_mask: torch.Tensor | None = None,
embodiment_ids: torch.Tensor | None = None,
):
return self.predict_action(fused_tokens, state, actions_gt, action_mask, embodiment_ids)
def _set_module_trainable(self, module: nn.Module, trainable: bool):
for param in module.parameters():
param.requires_grad = trainable
def set_finetune_flags(self):
finetune_vlm = _cfgget(self.config, "finetune_vlm", False)
finetune_language_model = _cfgget(self.config, "finetune_language_model", False)
finetune_vision_model = _cfgget(self.config, "finetune_vision_model", False)
has_explicit_branch_flags = any(
flag is not None for flag in (finetune_language_model, finetune_vision_model)
)
finetune_language_model = bool(finetune_language_model)
finetune_vision_model = bool(finetune_vision_model)
finetune_vlm = bool(finetune_vlm)
if has_explicit_branch_flags:
self._set_module_trainable(self.embedder, False)
if hasattr(self.embedder.model, "language_model"):
self._set_module_trainable(self.embedder.model.language_model, finetune_language_model)
if hasattr(self.embedder.model, "vision_model"):
self._set_module_trainable(self.embedder.model.vision_model, finetune_vision_model)
if hasattr(self.embedder.model, "mlp1"):
self._set_module_trainable(self.embedder.model.mlp1, finetune_vision_model)
elif not finetune_vlm:
self._set_module_trainable(self.embedder, False)
if not _cfgget(self.config, "finetune_action_head", False):
self._set_module_trainable(self.action_head, False)

View File

@@ -1,456 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import math
from types import SimpleNamespace
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
def _cfgget(config, key: str, default=None):
if isinstance(config, dict):
return config.get(key, default)
return getattr(config, key, default)
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, dim: int, max_len: int = 1000):
super().__init__()
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
def forward(self, seq_len: int):
if seq_len > self.pe.size(1):
self._extend_pe(seq_len)
return self.pe[:, :seq_len, :]
def _extend_pe(self, new_max_len):
old_max_len, dim = self.pe.size(1), self.pe.size(2)
if new_max_len <= old_max_len:
return
extra_positions = torch.arange(old_max_len, new_max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
extra_pe = torch.zeros(new_max_len - old_max_len, dim)
extra_pe[:, 0::2] = torch.sin(extra_positions * div_term)
extra_pe[:, 1::2] = torch.cos(extra_positions * div_term)
extra_pe = extra_pe.unsqueeze(0)
new_pe = torch.cat([self.pe, extra_pe.to(self.pe.device)], dim=1)
self.pe = new_pe
class CategorySpecificLinear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, num_categories: int = 1):
super().__init__()
self.num_categories = num_categories
if num_categories <= 1:
self.linear = nn.Linear(in_dim, out_dim)
else:
self.weight = nn.Parameter(torch.empty(num_categories, in_dim, out_dim))
self.bias = nn.Parameter(torch.zeros(num_categories, out_dim))
nn.init.xavier_uniform_(self.weight)
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
if self.num_categories <= 1:
if x.dtype != self.linear.weight.dtype:
x = x.to(dtype=self.linear.weight.dtype)
return self.linear(x)
if x.dtype != self.weight.dtype:
x = x.to(dtype=self.weight.dtype)
orig_shape = x.shape
x_flat = x.reshape(-1, orig_shape[-1])
if category_id.dim() == 0:
cid = category_id.item()
out = x_flat @ self.weight[cid] + self.bias[cid]
else:
category_id = category_id.reshape(-1)
if category_id.numel() != x_flat.size(0):
raise ValueError(
f"category_id length {category_id.numel()} does not match flattened batch {x_flat.size(0)}"
)
weight_selected = self.weight[category_id]
bias_selected = self.bias[category_id]
out = torch.bmm(x_flat.unsqueeze(1), weight_selected).squeeze(1) + bias_selected
out_shape = orig_shape[:-1] + (out.shape[-1],)
return out.view(out_shape)
class CategorySpecificMLP(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_categories: int = 1):
super().__init__()
self.fc1 = CategorySpecificLinear(input_dim, hidden_dim, num_categories)
self.fc2 = CategorySpecificLinear(hidden_dim, output_dim, num_categories)
self.activation = nn.ReLU(inplace=True)
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
out = self.activation(self.fc1(x, category_id))
out = self.fc2(out, category_id)
return out
class MultiEmbodimentActionEncoder(nn.Module):
def __init__(
self, action_dim: int, embed_dim: int, hidden_dim: int, horizon: int, num_categories: int = 1
):
super().__init__()
self.horizon = horizon
self.embed_dim = embed_dim
self.num_categories = num_categories
self.W1 = CategorySpecificLinear(action_dim, hidden_dim, num_categories)
self.W2 = CategorySpecificLinear(hidden_dim, hidden_dim, num_categories)
self.W3 = CategorySpecificLinear(hidden_dim, embed_dim, num_categories)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim, max_len=horizon)
self.activation = nn.ReLU(inplace=True)
def forward(self, action_seq: torch.Tensor, category_id: torch.LongTensor):
batch_size, horizon, action_dim = action_seq.shape
assert self.horizon == horizon, "Action sequence length must match horizon"
x = action_seq.reshape(batch_size * horizon, action_dim)
if category_id.dim() == 0:
cat_ids = category_id.expand(horizon * batch_size)
else:
cat_ids = category_id.unsqueeze(1).expand(batch_size, horizon).reshape(batch_size * horizon)
out = self.activation(self.W1(x, cat_ids))
pos_enc = self.pos_encoding(horizon).to(device=out.device, dtype=out.dtype)
out = out.view(batch_size, horizon, -1) + pos_enc
out = out.view(batch_size * horizon, -1)
out = self.activation(self.W2(out, cat_ids))
out = self.W3(out, cat_ids)
return out.view(batch_size, horizon, self.embed_dim)
class BasicTransformerBlock(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, hidden_dim: int, dropout: float = 0.0):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.ff = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, embed_dim))
def forward(self, action_tokens: torch.Tensor, context_tokens: torch.Tensor, time_emb: torch.Tensor):
x = self.norm1(action_tokens)
attn_out, _ = self.attn(x, context_tokens, context_tokens)
x = action_tokens + attn_out
x2 = self.norm2(x)
if time_emb is not None:
x2 = x2 + time_emb.unsqueeze(1)
ff_out = self.ff(x2)
return x + ff_out
class FlowmatchingActionHead(nn.Module):
def __init__(
self,
config=None,
embed_dim: int = 896,
hidden_dim: int = 1024,
action_dim: int = 16 * 7,
horizon: int = 16,
per_action_dim: int = 7,
num_heads: int = 8,
num_layers: int = 8,
dropout: float = 0.0,
num_inference_timesteps: int = 20,
num_categories: int = 1,
):
super().__init__()
if config is not None:
embed_dim = _cfgget(config, "embed_dim", embed_dim)
hidden_dim = _cfgget(config, "hidden_dim", hidden_dim)
action_dim = _cfgget(config, "action_dim", action_dim)
horizon = _cfgget(config, "horizon", horizon)
per_action_dim = _cfgget(config, "per_action_dim", per_action_dim)
num_heads = _cfgget(config, "num_heads", num_heads)
num_layers = _cfgget(config, "num_layers", num_layers)
dropout = _cfgget(config, "dropout", dropout)
num_inference_timesteps = _cfgget(config, "num_inference_timesteps", num_inference_timesteps)
num_categories = _cfgget(config, "num_categories", num_categories)
self.config = config
else:
self.config = SimpleNamespace(
embed_dim=embed_dim,
hidden_dim=hidden_dim,
action_dim=action_dim,
horizon=horizon,
per_action_dim=per_action_dim,
num_heads=num_heads,
num_layers=num_layers,
dropout=dropout,
num_inference_timesteps=num_inference_timesteps,
num_categories=num_categories,
)
logger.info("FlowmatchingActionHead num_inference_timesteps=%s", num_inference_timesteps)
self.embed_dim = embed_dim
self.horizon = horizon
self.per_action_dim = _cfgget(self.config, "per_action_dim", per_action_dim)
self.action_dim = _cfgget(self.config, "action_dim", action_dim)
self.time_pos_enc = SinusoidalPositionalEncoding(embed_dim, max_len=1000)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
embed_dim=embed_dim,
num_heads=num_heads,
hidden_dim=embed_dim * 4,
dropout=dropout,
)
for _ in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(embed_dim)
self.seq_pool_proj = nn.Linear(self.horizon * self.embed_dim, self.embed_dim)
self.mlp_head = CategorySpecificMLP(
input_dim=embed_dim,
hidden_dim=hidden_dim,
output_dim=action_dim,
num_categories=num_categories,
)
self.state_encoder = None
state_dim = _cfgget(self.config, "state_dim")
if state_dim is not None:
state_hidden = _cfgget(self.config, "state_hidden_dim", embed_dim)
self.state_encoder = CategorySpecificMLP(
input_dim=state_dim,
hidden_dim=state_hidden,
output_dim=embed_dim,
num_categories=num_categories,
)
if horizon > 1:
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=self.per_action_dim,
embed_dim=embed_dim,
hidden_dim=embed_dim,
horizon=horizon,
num_categories=num_categories,
)
self.single_action_proj = None
else:
self.action_encoder = None
self.single_action_proj = nn.Linear(self.per_action_dim, self.embed_dim)
def _project_actions(self, action_seq: torch.Tensor, embodiment_id: torch.LongTensor) -> torch.Tensor:
if self.horizon > 1 and self.action_encoder is not None:
return self.action_encoder(action_seq, embodiment_id)
if self.single_action_proj is None:
raise RuntimeError("single_action_proj is not initialized for horizon <= 1.")
return self.single_action_proj(action_seq)
def _expand_action_mask(
self,
action_mask: torch.Tensor,
batch_size: int,
per_action_dim: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
if action_mask is None:
raise ValueError("action_mask must be provided for flow matching inference.")
if action_mask.dim() == 2:
expected_last_dim = self.horizon * per_action_dim
if action_mask.shape == (batch_size, expected_last_dim):
expanded_mask = action_mask.reshape(batch_size, self.horizon, per_action_dim)
elif action_mask.shape == (batch_size, per_action_dim):
expanded_mask = action_mask.unsqueeze(1).expand(batch_size, self.horizon, per_action_dim)
else:
raise ValueError(
f"Expected action_mask shape {(batch_size, expected_last_dim)} or "
f"{(batch_size, per_action_dim)}, got {tuple(action_mask.shape)}"
)
elif action_mask.dim() == 3:
expected_shape = (batch_size, self.horizon, per_action_dim)
if tuple(action_mask.shape) != expected_shape:
raise ValueError(
f"Expected action_mask shape {expected_shape}, got {tuple(action_mask.shape)}"
)
expanded_mask = action_mask
else:
raise ValueError(f"Unsupported action_mask rank: {action_mask.dim()}")
return expanded_mask.to(device=device, dtype=dtype)
def forward(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor = None,
actions_gt: torch.Tensor = None,
embodiment_id: torch.LongTensor = None,
state_mask: torch.Tensor = None,
action_mask: torch.Tensor = None,
):
if actions_gt is None:
return self.get_action(
fused_tokens, state=state, embodiment_id=embodiment_id, action_mask=action_mask
)
batch_size = fused_tokens.size(0)
device = fused_tokens.device
if embodiment_id is None:
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
context_tokens = fused_tokens
if state is not None and self.state_encoder is not None:
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
t = (
torch.distributions.Beta(2, 2)
.sample((batch_size,))
.clamp(0.02, 0.98)
.to(device)
.to(dtype=self.dtype)
)
time_index = (t * 999).long().clamp_(0, 999)
time_emb = self.time_pos_enc(1000)[:, time_index, :].squeeze(0).to(dtype=context_tokens.dtype)
actions_gt_seq = actions_gt
noise = torch.rand_like(actions_gt) * 2 - 1
if action_mask is not None:
action_mask = action_mask.to(dtype=noise.dtype, device=noise.device)
if action_mask.shape != noise.shape:
raise ValueError(f"action_mask shape {action_mask.shape} != noise shape {noise.shape}")
actions_gt_seq = actions_gt_seq * action_mask
noise = noise * action_mask
if self.horizon > 1:
noise_seq = noise.view(batch_size, self.horizon, self.per_action_dim)
else:
noise_seq = noise if noise.dim() == 3 else noise.unsqueeze(1)
t_broadcast = t.view(batch_size, 1, 1)
action_intermediate_seq = (1 - t_broadcast) * noise_seq + t_broadcast * actions_gt_seq
action_tokens = self._project_actions(action_intermediate_seq, embodiment_id)
target_dtype = self.dtype
action_tokens = action_tokens.to(dtype=target_dtype)
context_tokens = context_tokens.to(dtype=target_dtype)
time_emb = time_emb.to(dtype=target_dtype)
x = action_tokens
for block in self.transformer_blocks:
x = block(x, context_tokens, time_emb)
x = self.norm_out(x)
if self.horizon > 1:
x_flat = x.reshape(batch_size, -1)
x_pooled = self.seq_pool_proj(x_flat)
else:
x_pooled = x.squeeze(1)
pred_velocity = self.mlp_head(x_pooled, embodiment_id)
return pred_velocity, noise
def get_action(
self,
fused_tokens: torch.Tensor,
state: torch.Tensor = None,
embodiment_id: torch.LongTensor = None,
action_mask: torch.Tensor = None,
):
batch_size = fused_tokens.size(0)
device = fused_tokens.device
if embodiment_id is None:
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
context_tokens = fused_tokens
if state is not None and self.state_encoder is not None:
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
action_dim_total = _cfgget(self.config, "action_dim", self.action_dim)
per_action_dim = _cfgget(self.config, "per_action_dim", action_dim_total // max(self.horizon, 1))
action = torch.rand(batch_size, action_dim_total, device=device, dtype=context_tokens.dtype) * 2 - 1
action_seq = (
action.view(batch_size, self.horizon, per_action_dim)
if self.horizon > 1
else action.view(batch_size, 1, per_action_dim)
)
action_mask = self._expand_action_mask(
action_mask,
batch_size=batch_size,
per_action_dim=per_action_dim,
device=action_seq.device,
dtype=action_seq.dtype,
)
action_seq = action_seq * action_mask
target_dtype = self.dtype
context_tokens = context_tokens.to(dtype=target_dtype)
num_steps = int(_cfgget(self.config, "num_inference_timesteps", 32))
if num_steps <= 0:
raise ValueError(f"num_inference_timesteps must be positive, got {num_steps}")
dt = 1.0 / num_steps
for i in range(num_steps):
t = i / num_steps
time_index = min(int(t * 999), 999)
time_emb = (
self.time_pos_enc(1000)[:, time_index, :].to(device).squeeze(0).to(dtype=context_tokens.dtype)
)
time_emb = time_emb.unsqueeze(0).repeat(batch_size, 1)
action_seq = action_seq * action_mask
action_tokens = self._project_actions(action_seq, embodiment_id).to(dtype=target_dtype)
time_emb = time_emb.to(dtype=target_dtype)
x = action_tokens
for block in self.transformer_blocks:
x = block(x, context_tokens, time_emb)
x = self.norm_out(x)
if self.horizon > 1:
x_flat = x.reshape(batch_size, -1)
x_pooled = self.seq_pool_proj(x_flat)
else:
x_pooled = x.squeeze(1)
pred = self.mlp_head(x_pooled, embodiment_id)
action = action + dt * pred
action_seq = (
action.view(batch_size, self.horizon, per_action_dim)
if self.horizon > 1
else action.view(batch_size, 1, per_action_dim)
)
action_seq = action_seq * action_mask
return action_seq.reshape(batch_size, -1)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype

View File

@@ -1,372 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import functools
import logging
from collections.abc import Sequence
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModel, AutoTokenizer
else:
AutoModel = None
AutoTokenizer = None
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>" # nosec B105
IMG_START_TOKEN = "<img>" # nosec B105
IMG_END_TOKEN = "</img>" # nosec B105
logger = logging.getLogger(__name__)
def flash_attn_is_available() -> bool:
try:
import flash_attn # noqa: F401
except ModuleNotFoundError:
return False
return True
@functools.lru_cache(maxsize=10000)
def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int):
aspect_ratio = orig_width / orig_height
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = orig_width * orig_height
for ratio in target_ratios:
target_ar = ratio[0] / ratio[1]
diff = abs(aspect_ratio - target_ar)
if diff < best_ratio_diff:
best_ratio_diff = diff
best_ratio = ratio
elif diff == best_ratio_diff and area > 0.5 * image_size**2 * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=1, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
ratio_w, ratio_h = get_target_aspect_ratio(orig_width, orig_height, image_size, min_num, max_num)
target_width = image_size * ratio_w
target_height = image_size * ratio_h
blocks = ratio_w * ratio_h
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
processed_images.append(resized_img.crop(box))
if use_thumbnail and len(processed_images) != 1:
processed_images.append(image.resize((image_size, image_size)))
return processed_images
class InternVL3Embedder(nn.Module):
def __init__(
self,
model_name="OpenGVLab/InternVL3-1B",
image_size=448,
device="cuda",
num_language_layers: int | None = 14,
model_dtype: str | torch.dtype = "bfloat16",
use_flash_attn: bool = True,
enable_gradient_checkpointing: bool = True,
gradient_checkpointing_use_reentrant: bool = False,
):
super().__init__()
self._requested_device = device
self.image_size = image_size
self.num_language_layers = num_language_layers
self.max_text_length = 1024
self.enable_gradient_checkpointing = bool(enable_gradient_checkpointing)
self.gradient_checkpointing_use_reentrant = bool(gradient_checkpointing_use_reentrant)
require_package("transformers", extra="evo1")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
if isinstance(model_dtype, str):
try:
model_dtype = getattr(torch, model_dtype)
except AttributeError as exc:
raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'") from exc
resolved_use_flash_attn = bool(use_flash_attn and flash_attn_is_available())
if use_flash_attn and not resolved_use_flash_attn:
logger.warning("flash_attn is not installed. Falling back to standard attention.")
self.model = AutoModel.from_pretrained(
model_name,
torch_dtype=model_dtype,
trust_remote_code=True,
use_flash_attn=resolved_use_flash_attn,
low_cpu_mem_usage=True,
_fast_init=False,
).to(self._requested_device)
if hasattr(self.model.language_model, "model"):
layers = self.model.language_model.model.layers
else:
layers = self.model.language_model.layers
if self.num_language_layers is not None:
layers = layers[: self.num_language_layers]
if hasattr(self.model.language_model, "model"):
self.model.language_model.model.layers = torch.nn.ModuleList(layers)
else:
self.model.language_model.layers = torch.nn.ModuleList(layers)
self.model.language_model.lm_head = torch.nn.Identity()
self._configure_memory_features()
self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
def _configure_memory_features(self) -> None:
checkpoint_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant}
if not self.enable_gradient_checkpointing:
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
self.model.vision_model.encoder.gradient_checkpointing = False
language_model = getattr(self.model, "language_model", None)
if language_model is not None:
if hasattr(language_model, "gradient_checkpointing_disable"):
language_model.gradient_checkpointing_disable()
elif hasattr(language_model, "gradient_checkpointing"):
language_model.gradient_checkpointing = False
if hasattr(language_model, "model"):
inner = language_model.model
if hasattr(inner, "gradient_checkpointing_disable"):
inner.gradient_checkpointing_disable()
elif hasattr(inner, "gradient_checkpointing"):
inner.gradient_checkpointing = False
return
def _enable_ckpt(module: nn.Module | None) -> bool:
if module is None:
return False
if hasattr(module, "gradient_checkpointing_enable"):
try:
module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=checkpoint_kwargs)
except TypeError:
module.gradient_checkpointing_enable()
return True
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = True
return True
return False
enabled_any = _enable_ckpt(self.model)
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
self.model.vision_model.encoder.gradient_checkpointing = True
enabled_any = True
language_model = getattr(self.model, "language_model", None)
if language_model is not None:
enabled_any = _enable_ckpt(language_model) or enabled_any
if hasattr(language_model, "model"):
enabled_any = _enable_ckpt(language_model.model) or enabled_any
if hasattr(language_model, "config"):
language_model.config.use_cache = False
if hasattr(self.model, "config"):
self.model.config.use_cache = False
if hasattr(self.model, "enable_input_require_grads"):
self.model.enable_input_require_grads()
if enabled_any:
logger.info("Gradient checkpointing enabled for InternVL3 embedder.")
else:
logger.warning(
"Requested gradient checkpointing, but model does not expose checkpointing controls."
)
def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor:
if isinstance(image, torch.Tensor):
pil_image = to_pil_image(image.detach().cpu())
else:
pil_image = image.convert("RGB")
tiles = dynamic_preprocess(pil_image, image_size=self.image_size)
tile_tensors = torch.stack([TF.to_tensor(tile) for tile in tiles]).to(
device=self.device, dtype=torch.bfloat16
)
mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
return (tile_tensors - mean) / std
def _preprocess_images(
self,
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
) -> tuple[torch.Tensor, list[list[int]]]:
pixel_values_list = []
batch_num_tiles_list: list[list[int]] = []
for image_tensors in image_tensors_batch:
num_tiles_list: list[int] = []
for image in image_tensors:
tiles = self._preprocess_single_image(image)
pixel_values_list.append(tiles)
num_tiles_list.append(int(tiles.shape[0]))
batch_num_tiles_list.append(num_tiles_list)
if pixel_values_list:
pixel_values = torch.cat(pixel_values_list, dim=0)
else:
pixel_values = torch.empty(
0, 3, self.image_size, self.image_size, dtype=torch.bfloat16, device=self.device
)
return pixel_values, batch_num_tiles_list
def _build_multimodal_prompts(
self,
batch_num_tiles_list: list[list[int]],
text_prompts: Sequence[str],
) -> list[str]:
prompts = []
for num_tiles_list, text_prompt in zip(batch_num_tiles_list, text_prompts, strict=True):
prompt_segments = []
for i, tile_count in enumerate(num_tiles_list):
token_count = self.model.num_image_token * tile_count
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * token_count + IMG_END_TOKEN
prompt_segments.append(f"Image-{i + 1}: {image_tokens}\n")
prompts.append("".join(prompt_segments) + text_prompt.strip())
return prompts
def _prepare_and_fuse_embeddings(
self,
prompts: Sequence[str],
vit_embeds: torch.Tensor,
image_masks: torch.Tensor,
batch_num_tiles_list: list[list[int]],
) -> tuple[torch.Tensor, torch.Tensor]:
untruncated_ids = self.tokenizer(list(prompts), padding=False, truncation=False)["input_ids"]
true_sequence_length = max((len(ids) for ids in untruncated_ids), default=0)
if true_sequence_length > self.max_text_length:
logger.warning(
"InternVL3 prompt truncated in batch: max_length=%s actual_max_length=%s",
self.max_text_length,
true_sequence_length,
)
model_inputs = self.tokenizer(
list(prompts),
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_text_length,
).to(self.device)
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]
img_token_mask = input_ids == self.img_context_token_id
input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone()
batch_size, _, channels = input_embeds.shape
vit_embeds = vit_embeds.reshape(-1, channels).to(dtype=input_embeds.dtype, device=input_embeds.device)
tokens_per_tile = self.model.num_image_token
actual_vis_tokens_list = img_token_mask.sum(dim=1).tolist()
vit_idx = 0
for batch_index in range(batch_size):
expected_vis_tokens = sum(batch_num_tiles_list[batch_index]) * tokens_per_tile
mask_b = img_token_mask[batch_index]
actual_vis_tokens = actual_vis_tokens_list[batch_index]
item_vit_embeds = vit_embeds[vit_idx : vit_idx + expected_vis_tokens]
vit_idx += expected_vis_tokens
if actual_vis_tokens > 0:
if item_vit_embeds.shape[0] < actual_vis_tokens:
raise ValueError(
f"InternVL3 produced fewer image tokens than expected for sample {batch_index}: "
f"got {item_vit_embeds.shape[0]}, need {actual_vis_tokens}"
)
input_embeds[batch_index, mask_b] = item_vit_embeds[:actual_vis_tokens]
current_token_idx = 0
img_token_locations = torch.where(mask_b)[0]
for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]):
num_tokens_for_image = num_tiles * tokens_per_tile
if not bool(image_masks[batch_index, image_index].item()):
start_offset = current_token_idx
end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations))
if start_offset < end_offset:
idxs = img_token_locations[start_offset:end_offset]
attention_mask[batch_index, idxs] = 0
current_token_idx += num_tokens_for_image
return input_embeds, attention_mask
def get_fused_image_text_embedding_from_tensor_images(
self,
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
image_masks: torch.Tensor,
text_prompts: Sequence[str],
return_cls_only: bool = True,
):
pixel_values, batch_num_tiles_list = self._preprocess_images(image_tensors_batch)
if pixel_values.shape[0] == 0:
logger.warning("InternVL3 received an empty image batch after preprocessing.")
hidden_size = getattr(self.model.config, "hidden_size", None)
if hidden_size is None and hasattr(self.model.language_model, "config"):
hidden_size = getattr(self.model.language_model.config, "hidden_size", None)
if hidden_size is None:
raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.")
empty = torch.empty(0, hidden_size, device=self.device, dtype=torch.float32)
return empty
prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts)
vit_embeds = self.model.extract_feature(pixel_values)
inputs_embeds, attention_mask = self._prepare_and_fuse_embeddings(
prompts,
vit_embeds,
image_masks.to(device=self.device),
batch_num_tiles_list,
)
outputs = self.model.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
return_dict=True,
)
fused_hidden = outputs.hidden_states[-1].to(torch.float32)
return fused_hidden[:, 0, :] if return_cls_only else fused_hidden
@property
def device(self) -> torch.device:
return next(self.model.parameters()).device

View File

@@ -1,426 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import builtins
from collections import deque
from contextlib import nullcontext
from pathlib import Path
import torch
from torch import Tensor
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.policies.evo1.evo1_model import EVO1
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
class EVO1Policy(PreTrainedPolicy):
config_class = Evo1Config
name = "evo1"
def __init__(self, config: Evo1Config, **kwargs):
super().__init__(config)
config.validate_features()
if len(config.image_features) > config.max_views:
raise ValueError(
f"EVO1 supports at most {config.max_views} camera streams, got {len(config.image_features)}"
)
self.config = config
self.model = EVO1(self._build_model_config(config))
self.model.set_finetune_flags()
self.reset()
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool | None = None,
**kwargs,
) -> T:
if strict is None:
strict = not (config is not None and getattr(config, "training_stage", None) == "stage2")
return super().from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
config=config,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
strict=strict,
**kwargs,
)
@staticmethod
def _build_model_config(config: Evo1Config) -> dict:
return {
"device": config.device,
"return_cls_only": config.return_cls_only,
"vlm_name": config.vlm_model_name,
"vlm_num_layers": config.vlm_num_layers,
"vlm_dtype": config.vlm_dtype,
"use_flash_attn": config.use_flash_attn,
"action_head": config.action_head,
"action_horizon": config.chunk_size,
"per_action_dim": config.max_action_dim,
"state_dim": config.max_state_dim,
"embed_dim": config.embed_dim,
"hidden_dim": config.hidden_dim,
"state_hidden_dim": config.state_hidden_dim,
"num_heads": config.num_heads,
"num_layers": config.num_layers,
"dropout": config.dropout,
"num_inference_timesteps": config.num_inference_timesteps,
"num_categories": config.num_categories,
"enable_gradient_checkpointing": config.enable_gradient_checkpointing,
"gradient_checkpointing_use_reentrant": config.gradient_checkpointing_use_reentrant,
"finetune_vlm": config.finetune_vlm,
"finetune_language_model": config.finetune_language_model,
"finetune_vision_model": config.finetune_vision_model,
"finetune_action_head": config.finetune_action_head,
}
@property
def _camera_keys(self) -> list[str]:
return list(self.config.image_features)
@property
def _env_action_dim(self) -> int:
action_feature = self.config.action_feature
if action_feature is None:
return self.config.max_action_dim
return int(action_feature.shape[0])
@property
def _compute_dtype(self) -> torch.dtype:
return next(self.model.action_head.parameters()).dtype
@property
def _training_compute_dtype(self) -> torch.dtype:
if str(self.config.device).startswith("cuda"):
return torch.bfloat16
return self._compute_dtype
@property
def _inference_compute_dtype(self) -> torch.dtype:
if str(self.config.device).startswith("cuda") and self.config.use_amp:
return torch.bfloat16
return self._compute_dtype
def get_optim_params(self) -> list[dict]:
decay, no_decay = [], []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
is_bias = name.endswith("bias") or ".bias" in name
is_norm = param.dim() == 1 or "norm" in name.lower()
if is_bias or is_norm:
no_decay.append(param)
else:
decay.append(param)
return [
{"params": decay, "weight_decay": self.config.optimizer_weight_decay},
{"params": no_decay, "weight_decay": 0.0},
]
def reset(self):
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def _normalize_task_batch(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
prompts = batch.get(self.config.task_field)
if prompts is None and self.config.task_field != "task":
prompts = batch.get("task")
if prompts is None:
raise ValueError(f"EVO1 expects a '{self.config.task_field}' text field in the batch.")
if isinstance(prompts, str):
return [prompts]
if isinstance(prompts, (list, tuple)):
return [str(prompt) for prompt in prompts]
raise TypeError(f"Unsupported prompt batch type: {type(prompts)}")
def _prepare_state(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
if OBS_STATE not in batch:
raise ValueError(f"EVO1 requires '{OBS_STATE}' in the batch.")
state = batch[OBS_STATE]
if state.dim() == 1:
state = state.unsqueeze(0)
elif state.dim() == 3:
state = state[:, -1]
elif state.dim() != 2:
raise ValueError(f"Unsupported state tensor shape for EVO1: {tuple(state.shape)}")
batch_size, state_dim = state.shape
if state_dim > self.config.max_state_dim:
raise ValueError(
f"State dim {state_dim} exceeds configured max_state_dim {self.config.max_state_dim}"
)
explicit_mask = batch.get("state_mask")
if explicit_mask is not None:
if explicit_mask.dim() == 1:
explicit_mask = explicit_mask.unsqueeze(0)
elif explicit_mask.dim() == 3:
explicit_mask = explicit_mask[:, -1]
elif explicit_mask.dim() != 2:
raise ValueError(
f"Unsupported state_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
)
if explicit_mask.shape != (batch_size, state_dim):
raise ValueError(
f"state_mask shape {tuple(explicit_mask.shape)} does not match state shape {(batch_size, state_dim)}"
)
padded = torch.zeros(
batch_size,
self.config.max_state_dim,
dtype=state.dtype,
device=self.config.device,
)
padded[:, :state_dim] = state.to(device=self.config.device)
mask = torch.zeros(
batch_size,
self.config.max_state_dim,
dtype=torch.bool,
device=self.config.device,
)
if explicit_mask is None:
mask[:, :state_dim] = True
else:
mask[:, :state_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
return padded.to(dtype=self._compute_dtype), mask
def _prepare_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
if ACTION not in batch:
raise ValueError(f"EVO1 requires '{ACTION}' in the batch for training.")
action = batch[ACTION]
if action.dim() == 2:
action = action.unsqueeze(1)
batch_size, horizon, action_dim = action.shape
if horizon != self.config.chunk_size:
raise ValueError(
f"EVO1 expects chunk_size={self.config.chunk_size}, got action horizon {horizon}"
)
if action_dim > self.config.max_action_dim:
raise ValueError(
f"Action dim {action_dim} exceeds configured max_action_dim {self.config.max_action_dim}"
)
explicit_mask = batch.get("action_mask")
if explicit_mask is not None:
if explicit_mask.dim() == 2:
if horizon == 1:
explicit_mask = explicit_mask.unsqueeze(1)
else:
raise ValueError(
f"2D action_mask is only supported when chunk_size=1, got action horizon {horizon}"
)
elif explicit_mask.dim() != 3:
raise ValueError(
f"Unsupported action_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
)
if explicit_mask.shape != (batch_size, horizon, action_dim):
raise ValueError(
"action_mask shape "
f"{tuple(explicit_mask.shape)} does not match action shape {(batch_size, horizon, action_dim)}"
)
padded = torch.zeros(
batch_size,
horizon,
self.config.max_action_dim,
dtype=action.dtype,
device=self.config.device,
)
padded[:, :, :action_dim] = action.to(device=self.config.device)
mask = torch.zeros(
batch_size,
horizon,
self.config.max_action_dim,
dtype=torch.bool,
device=self.config.device,
)
if explicit_mask is None:
mask[:, :, :action_dim] = True
else:
mask[:, :, :action_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
return padded.to(dtype=self._compute_dtype), mask
def _prepare_inference_action_mask(self, batch_size: int) -> Tensor:
mask = torch.zeros(
batch_size,
self.config.max_action_dim,
dtype=torch.bool,
device=self.config.device,
)
mask[:, : self._env_action_dim] = True
return mask
def _get_embodiment_ids(self, batch: dict[str, Tensor], batch_size: int) -> Tensor:
embodiment_ids = batch.get("embodiment_id")
if embodiment_ids is None and self.config.embodiment_id_field:
embodiment_ids = batch.get(self.config.embodiment_id_field)
if embodiment_ids is None:
return torch.full(
(batch_size,),
self.config.default_embodiment_id,
dtype=torch.long,
device=self.config.device,
)
if embodiment_ids.dim() == 0:
embodiment_ids = embodiment_ids.unsqueeze(0)
elif embodiment_ids.dim() > 1:
embodiment_ids = embodiment_ids[:, -1]
return embodiment_ids.to(device=self.config.device, dtype=torch.long)
def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]:
camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
if not camera_keys:
raise ValueError("EVO1 requires at least one visual observation feature.")
# Normalize each camera tensor to (B, C, H, W) up-front so that batch_size is read
# from a real batch dim and not from C in the unbatched (C, H, W) case.
normalized: dict[str, Tensor] = {}
for camera_key in camera_keys[: self.config.max_views]:
image = batch[camera_key]
if image.dim() == 3:
image = image.unsqueeze(0)
elif image.dim() == 5:
image = image[:, -1]
elif image.dim() != 4:
raise ValueError(
f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}"
)
normalized[camera_key] = image
batch_size = normalized[camera_keys[0]].shape[0]
image_batches: list[list[Tensor]] = []
image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool)
for batch_index in range(batch_size):
sample_images: list[Tensor] = []
for camera_key in camera_keys[: self.config.max_views]:
sample_images.append(normalized[camera_key][batch_index].detach().cpu())
if not sample_images:
raise ValueError("EVO1 received a batch without any image tensor.")
while len(sample_images) < self.config.max_views:
sample_images.append(torch.zeros_like(sample_images[0]))
image_batches.append(sample_images[: self.config.max_views])
image_masks[batch_index, : min(len(camera_keys), self.config.max_views)] = True
return image_batches, image_masks
def _compute_fused_tokens(
self,
prompts: list[str],
image_batches: list[list[Tensor]],
image_masks: Tensor,
) -> Tensor:
fused_tokens = self.model.get_vl_embeddings(
images=image_batches,
image_mask=image_masks,
prompt=prompts,
return_cls_only=self.config.return_cls_only,
)
return fused_tokens.to(device=self.config.device, dtype=self._compute_dtype)
def _compute_masked_loss(
self,
pred_velocity: Tensor,
target_velocity: Tensor,
action_mask: Tensor,
reduction: str,
) -> Tensor:
flat_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=pred_velocity.dtype)
sq_error = ((pred_velocity - target_velocity) * flat_mask).pow(2)
active = flat_mask.sum(dim=1).clamp_min(1.0)
per_sample_loss = sq_error.sum(dim=1) / active
if reduction == "none":
return per_sample_loss
if reduction != "mean":
raise ValueError(f"Unsupported reduction '{reduction}'")
return sq_error.sum() / active.sum()
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
prompts = self._normalize_task_batch(batch)
image_batches, image_masks = self._collect_image_batches(batch)
states, _state_mask = self._prepare_state(batch)
actions_gt, action_mask = self._prepare_actions(batch)
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
states = states.to(dtype=self._training_compute_dtype)
actions_gt = actions_gt.to(dtype=self._training_compute_dtype)
fused_tokens = fused_tokens.to(dtype=self._training_compute_dtype)
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
pred_velocity, noise = self.model(
fused_tokens,
state=states,
actions_gt=actions_gt,
action_mask=action_mask.to(device=self.config.device, dtype=self._compute_dtype),
embodiment_ids=embodiment_ids,
)
flat_action_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=actions_gt.dtype)
target_velocity = (actions_gt - noise).view(actions_gt.shape[0], -1) * flat_action_mask
loss = self._compute_masked_loss(pred_velocity, target_velocity, action_mask, reduction)
loss_mean = loss.mean().item() if loss.ndim > 0 else loss.item()
return loss, {
"loss": loss_mean,
"active_action_dims": float(action_mask.sum(dim=(1, 2)).float().mean().item()),
}
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
self.eval()
prompts = self._normalize_task_batch(batch)
image_batches, image_masks = self._collect_image_batches(batch)
states, _state_mask = self._prepare_state(batch)
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
states = states.to(dtype=self._inference_compute_dtype)
fused_tokens = fused_tokens.to(dtype=self._inference_compute_dtype)
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
action_mask = self._prepare_inference_action_mask(states.shape[0])
with (
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
if self.config.use_amp and str(self.config.device).startswith("cuda")
else nullcontext()
):
actions = self.model(
fused_tokens,
state=states,
action_mask=action_mask,
embodiment_ids=embodiment_ids,
)
actions = actions.view(states.shape[0], self.config.chunk_size, self.config.max_action_dim)
return actions[:, :, : self._env_action_dim]
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
self.eval()
if len(self._action_queue) == 0:
action_chunk = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
self._action_queue.extend(action_chunk.transpose(0, 1))
return self._action_queue.popleft()

View File

@@ -1,106 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
import torch
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import (
batch_to_transition,
create_transition,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.utils.constants import (
ACTION,
DONE,
INFO,
OBS_PREFIX,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
REWARD,
TRUNCATED,
)
def evo1_batch_to_transition(batch: dict[str, Any]):
transition = batch_to_transition(batch)
complementary_data = dict(transition.get("complementary_data") or {})
reserved = {ACTION, REWARD, DONE, TRUNCATED, INFO}
for key, value in batch.items():
if key in reserved or key.startswith(OBS_PREFIX):
continue
complementary_data.setdefault(key, value)
return create_transition(
observation=transition.get("observation"),
action=transition.get("action"),
reward=transition.get("reward", 0.0),
done=transition.get("done", False),
truncated=transition.get("truncated", False),
info=transition.get("info", {}),
complementary_data=complementary_data,
)
def make_evo1_pre_post_processors(
config: Evo1Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device=config.device),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
to_transition=evo1_batch_to_transition,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)

View File

@@ -46,8 +46,6 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
from .act.configuration_act import ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .evo1.configuration_evo1 import Evo1Config
from .groot.configuration_groot import GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
@@ -89,7 +87,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x", "eo1", "evo1".
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -148,14 +146,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .wall_x.modeling_wall_x import WallXPolicy
return WallXPolicy
elif name == "eo1":
from .eo1.modeling_eo1 import EO1Policy
return EO1Policy
elif name == "evo1":
from .evo1.modeling_evo1 import EVO1Policy
return EVO1Policy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -173,7 +163,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"smolvla", "wall_x", "eo1", "evo1".
"smolvla", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -206,10 +196,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return XVLAConfig(**kwargs)
elif policy_type == "wall_x":
return WallXConfig(**kwargs)
elif policy_type == "eo1":
return EO1Config(**kwargs)
elif policy_type == "evo1":
return Evo1Config(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -413,20 +399,6 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, EO1Config):
from .eo1.processor_eo1 import make_eo1_pre_post_processors
processors = make_eo1_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, Evo1Config):
from .evo1.processor_evo1 import make_evo1_pre_post_processors
processors = make_evo1_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:

View File

@@ -97,8 +97,8 @@ class VQBeTConfig(PreTrainedConfig):
vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84)
crop_is_random: bool = True
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
use_group_norm: bool = False
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
spatial_softmax_num_keypoints: int = 32
# VQ-VAE
n_vqvae_training_steps: int = 20000

View File

@@ -939,7 +939,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_dtype(query_states.device.type)
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype

View File

@@ -985,7 +985,7 @@ class Florence2FlashAttention2(Florence2Attention):
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_dtype(query_states.device.type)
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype

View File

@@ -95,6 +95,13 @@ from .relative_action_processor import (
from .rename_processor import RenameObservationsProcessorStep, rename_stats
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
# RenderMessagesStep is intentionally NOT re-exported here: it pulls in
# `lerobot.datasets.language`, which requires the `[dataset]` extra
# (`datasets`, `pyarrow`). Importing it from the processor package would
# break every base-install consumer of `lerobot.processor`. Users that
# need it import directly:
# from lerobot.processor.render_messages_processor import RenderMessagesStep
__all__ = [
"ActionProcessorStep",
"AddTeleopActionAsComplimentaryDataStep",

View File

@@ -174,6 +174,24 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
complementary_data.pop("language_persistent", None)
complementary_data.pop("language_events", None)
if "messages" in complementary_data:
messages = complementary_data["messages"]
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
complementary_data["messages"] = [messages]
if "message_streams" in complementary_data:
streams = complementary_data["message_streams"]
if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
complementary_data["message_streams"] = [streams]
if "target_message_indices" in complementary_data:
indices = complementary_data["target_message_indices"]
if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
complementary_data["target_message_indices"] = [indices]
return complementary_data
def transform_features(

View File

@@ -167,12 +167,35 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
timestamp_key = {"timestamp": batch["timestamp"]} if "timestamp" in batch else {}
language_persistent_key = (
{"language_persistent": batch["language_persistent"]} if "language_persistent" in batch else {}
)
language_events_key = {"language_events": batch["language_events"]} if "language_events" in batch else {}
messages_key = {"messages": batch["messages"]} if "messages" in batch else {}
message_streams_key = {"message_streams": batch["message_streams"]} if "message_streams" in batch else {}
target_message_indices_key = (
{"target_message_indices": batch["target_message_indices"]}
if "target_message_indices" in batch
else {}
)
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
return {
**pad_keys,
**task_key,
**index_key,
**task_index_key,
**episode_index_key,
**timestamp_key,
**language_persistent_key,
**language_events_key,
**messages_key,
**message_streams_key,
**target_message_indices_key,
}
def create_transition(

View File

@@ -0,0 +1,92 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.configs.recipe import TrainingRecipe
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
from lerobot.datasets.language_render import render_sample
from lerobot.types import EnvTransition, TransitionKey
from .pipeline import ProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="render_messages_processor")
class RenderMessagesStep(ProcessorStep):
"""Processor step that turns raw language columns into rendered chat messages.
Reads ``language_persistent`` and ``language_events`` from the transition's
complementary data, renders them through ``recipe`` at the sample timestamp,
and replaces the raw columns with the resulting ``messages`` /
``message_streams`` / ``target_message_indices`` keys.
"""
recipe: TrainingRecipe
dataset_ctx: Any | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
"""Render messages for a single transition; return ``None`` to drop it."""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
events = complementary_data.get(LANGUAGE_EVENTS) or []
if not persistent and not events:
return transition
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
sample_idx = complementary_data.get("index", 0)
rendered = render_sample(
recipe=self.recipe,
persistent=persistent,
events=events,
t=_scalar(timestamp),
sample_idx=int(_scalar(sample_idx)),
task=complementary_data.get("task"),
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
return None
new_transition = transition.copy()
new_complementary_data = dict(complementary_data)
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Pass features through unchanged; rendering only touches complementary data."""
return features
def _scalar(value: Any) -> float | int:
"""Unwrap a tensor/array/single-element list into a Python scalar."""
if hasattr(value, "item"):
return value.item()
if isinstance(value, list) and len(value) == 1:
return _scalar(value[0])
return value

View File

@@ -54,7 +54,6 @@ class BiOpenArmFollower(Robot):
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_cameras,
side=config.left_arm_config.side,
@@ -73,7 +72,6 @@ class BiOpenArmFollower(Robot):
calibration_dir=config.calibration_dir,
port=config.right_arm_config.port,
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
max_relative_target=config.right_arm_config.max_relative_target,
cameras=right_cameras,
side=config.right_arm_config.side,

View File

@@ -46,7 +46,7 @@ class LeKiwiConfig(RobotConfig):
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = True
use_degrees: bool = False
@dataclass

View File

@@ -66,10 +66,6 @@ class OpenArmFollowerConfigBase:
# Whether to disable torque when disconnecting
disable_torque_on_disconnect: bool = True
# When True, expose `.vel` and `.torque` per motor in observation features.
# Default False for compatibility with the position-only openarm_mini teleoperator.
use_velocity_and_torque: bool = False
# Safety limit for relative target positions
# Set to a positive scalar for all motors, or a dict mapping motor names to limits
max_relative_target: float | dict[str, float] | None = None

View File

@@ -93,9 +93,8 @@ class OpenArmFollower(Robot):
features: dict[str, type] = {}
for motor in self.bus.motors:
features[f"{motor}.pos"] = float
if self.config.use_velocity_and_torque:
features[f"{motor}.vel"] = float
features[f"{motor}.torque"] = float
features[f"{motor}.vel"] = float # Add this
features[f"{motor}.torque"] = float # Add this
return features
@property
@@ -236,9 +235,8 @@ class OpenArmFollower(Robot):
for motor in self.bus.motors:
state = states.get(motor, {})
obs_dict[f"{motor}.pos"] = state.get("position", 0.0)
if self.config.use_velocity_and_torque:
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
# Capture images from cameras
for cam_key, cam in self.cameras.items():

View File

@@ -48,6 +48,7 @@ from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.rewards import make_reward_pre_post_processors
from lerobot.utils.collate import lerobot_collate_fn
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
@@ -401,6 +402,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
shuffle = True
sampler = None
# Only swap in the language-aware collate when the dataset actually
# declares language columns; otherwise stay on PyTorch's default
# collate so non-language training runs are unaffected.
collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=cfg.num_workers,
@@ -409,6 +414,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
sampler=sampler,
pin_memory=device.type == "cuda",
drop_last=False,
collate_fn=collate_fn,
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
)

View File

@@ -49,7 +49,6 @@ class BiOpenArmLeader(Teleoperator):
can_data_bitrate=config.left_arm_config.can_data_bitrate,
motor_config=config.left_arm_config.motor_config,
manual_control=config.left_arm_config.manual_control,
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
position_kd=config.left_arm_config.position_kd,
position_kp=config.left_arm_config.position_kp,
)
@@ -64,7 +63,6 @@ class BiOpenArmLeader(Teleoperator):
can_data_bitrate=config.right_arm_config.can_data_bitrate,
motor_config=config.right_arm_config.motor_config,
manual_control=config.right_arm_config.manual_control,
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
position_kd=config.right_arm_config.position_kd,
position_kp=config.right_arm_config.position_kp,
)

View File

@@ -60,10 +60,6 @@ class OpenArmLeaderConfigBase:
# When enabled, motors have torque disabled for manual movement
manual_control: bool = True
# When True, expose `.vel` and `.torque` per motor in action features.
# Default False for compatibility with the position-only openarm_mini teleoperator.
use_velocity_and_torque: bool = False
# TODO(Steven, Pepijn): Not used ... ?
# MIT control parameters (used when manual_control=False for torque control)
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]

View File

@@ -70,9 +70,8 @@ class OpenArmLeader(Teleoperator):
features: dict[str, type] = {}
for motor in self.bus.motors:
features[f"{motor}.pos"] = float
if self.config.use_velocity_and_torque:
features[f"{motor}.vel"] = float
features[f"{motor}.torque"] = float
features[f"{motor}.vel"] = float
features[f"{motor}.torque"] = float
return features
@property
@@ -202,9 +201,8 @@ class OpenArmLeader(Teleoperator):
for motor in self.bus.motors:
state = states.get(motor, {})
action_dict[f"{motor}.pos"] = state.get("position")
if self.config.use_velocity_and_torque:
action_dict[f"{motor}.vel"] = state.get("velocity")
action_dict[f"{motor}.torque"] = state.get("torque")
action_dict[f"{motor}.vel"] = state.get("velocity")
action_dict[f"{motor}.torque"] = state.get("torque")
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")

View File

@@ -0,0 +1,54 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from torch.utils.data._utils.collate import default_collate
from lerobot.datasets.language import LANGUAGE_COLUMNS
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
"""Collate function that preserves Python-list and language fields as lists.
Drops ``None`` samples (e.g. recipes that yielded no target message), keeps
rendered-message and language fields as plain Python lists, and delegates
every other key to PyTorch's ``default_collate``.
"""
batch = [sample for sample in batch if sample is not None]
if not batch:
return None
preserved = {
key: [sample[key] for sample in batch if key in sample]
for key in _PYTHON_LIST_KEYS
if any(key in sample for sample in batch)
}
tensorizable = [
{
key: value
for key, value in sample.items()
if key not in _PYTHON_LIST_KEYS and key not in LANGUAGE_COLUMNS
}
for sample in batch
]
collated = default_collate(tensorizable)
collated.update(preserved)
return collated

View File

@@ -0,0 +1,15 @@
#!/usr/bin/env python
import pytest
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
def test_message_recipe_validates_unknown_binding():
with pytest.raises(ValueError, match="unknown binding"):
TrainingRecipe(
messages=[
MessageTurn(role="user", content="${missing}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)

View File

@@ -385,3 +385,84 @@ def test_finalize_flushes_buffered_metadata(tmp_path):
assert episodes_dir.exists()
parquet_files = list(episodes_dir.rglob("*.parquet"))
assert len(parquet_files) > 0
# ── Tools accessor ───────────────────────────────────────────────────
def test_tools_falls_back_to_default_when_info_has_no_tools_field(tmp_path):
"""meta.tools returns DEFAULT_TOOLS when info.json doesn't declare any."""
from lerobot.datasets.language import DEFAULT_TOOLS
root = tmp_path / "no_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/no_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
assert meta.tools == DEFAULT_TOOLS
# info.json on disk should NOT include a `tools` key for clean datasets
with open(root / INFO_PATH) as f:
info_on_disk = json.load(f)
assert "tools" not in info_on_disk
def test_tools_reads_declared_tools_from_info_json(tmp_path):
"""A `tools` list written into info.json survives load → meta.tools.
Regression test for the bug where ``DatasetInfo.from_dict`` silently
dropped the ``tools`` key (no matching dataclass field), so
``meta.tools`` always returned ``DEFAULT_TOOLS`` regardless of
what was on disk.
"""
from lerobot.datasets.io_utils import load_info
root = tmp_path / "with_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/with_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
custom_tool = {
"type": "function",
"function": {
"name": "record_observation",
"description": "Capture a still image.",
"parameters": {
"type": "object",
"properties": {"label": {"type": "string"}},
"required": ["label"],
},
},
}
info_path = root / INFO_PATH
with open(info_path) as f:
raw = json.load(f)
raw["tools"] = [custom_tool]
with open(info_path, "w") as f:
json.dump(raw, f)
# Reload info from disk and rebind it on the metadata object
meta.info = load_info(root)
assert meta.tools == [custom_tool]
def test_tools_round_trip_through_dataset_info(tmp_path):
"""A `tools` list survives DatasetInfo.from_dict / to_dict."""
from lerobot.datasets.utils import DatasetInfo
raw = {
"codebase_version": "v3.1",
"fps": 30,
"features": SIMPLE_FEATURES,
"tools": [{"type": "function", "function": {"name": "say"}}],
}
info = DatasetInfo.from_dict(raw)
assert info.tools == raw["tools"]
assert info.to_dict()["tools"] == raw["tools"]

View File

@@ -0,0 +1,156 @@
#!/usr/bin/env python
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import numpy as np # noqa: E402
import pandas as pd # noqa: E402
import pyarrow as pa # noqa: E402
from lerobot.datasets import LeRobotDataset # noqa: E402
from lerobot.datasets.io_utils import write_info # noqa: E402
from lerobot.datasets.language import ( # noqa: E402
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
VIEW_DEPENDENT_STYLES,
column_for_style,
is_view_dependent_style,
language_events_arrow_type,
language_feature_info,
language_persistent_arrow_type,
validate_camera_field,
)
from lerobot.datasets.utils import DEFAULT_DATA_PATH # noqa: E402
def test_language_arrow_schema_has_expected_fields():
persistent_row_type = language_persistent_arrow_type().value_type
event_row_type = language_events_arrow_type().value_type
assert isinstance(persistent_row_type, pa.StructType)
assert persistent_row_type.names == [
"role",
"content",
"style",
"timestamp",
"camera",
"tool_calls",
]
assert isinstance(event_row_type, pa.StructType)
assert event_row_type.names == ["role", "content", "style", "camera", "tool_calls"]
def test_style_registry_routes_columns():
assert {"subtask", "plan", "memory", "motion", "task_aug"} == PERSISTENT_STYLES
assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES
assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY
assert column_for_style("subtask") == LANGUAGE_PERSISTENT
assert column_for_style("plan") == LANGUAGE_PERSISTENT
assert column_for_style("memory") == LANGUAGE_PERSISTENT
assert column_for_style("motion") == LANGUAGE_PERSISTENT
assert column_for_style("task_aug") == LANGUAGE_PERSISTENT
assert column_for_style("interjection") == LANGUAGE_EVENTS
assert column_for_style("vqa") == LANGUAGE_EVENTS
assert column_for_style("trace") == LANGUAGE_EVENTS
assert column_for_style(None) == LANGUAGE_EVENTS
def test_view_dependent_styles():
# motion lives in PERSISTENT_STYLES and is described in robot-frame
# (joint / Cartesian) terms, so it is NOT view-dependent. Only vqa
# (event) and trace (event, pixel-trajectory) carry a camera tag.
assert {"vqa", "trace"} == VIEW_DEPENDENT_STYLES
assert is_view_dependent_style("vqa")
assert is_view_dependent_style("trace")
assert not is_view_dependent_style("motion")
assert not is_view_dependent_style("subtask")
assert not is_view_dependent_style("plan")
assert not is_view_dependent_style("interjection")
assert not is_view_dependent_style(None)
def test_validate_camera_field_requires_camera_for_view_dependent_styles():
validate_camera_field("vqa", "observation.images.top")
validate_camera_field("trace", "observation.images.front")
with pytest.raises(ValueError, match="view-dependent"):
validate_camera_field("vqa", None)
with pytest.raises(ValueError, match="view-dependent"):
validate_camera_field("trace", "")
def test_validate_camera_field_rejects_camera_on_non_view_dependent_styles():
validate_camera_field("subtask", None)
validate_camera_field("plan", None)
validate_camera_field("memory", None)
validate_camera_field("motion", None)
validate_camera_field("interjection", None)
validate_camera_field(None, None)
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("subtask", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("motion", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("interjection", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field(None, "observation.images.top")
def test_unknown_style_rejected():
with pytest.raises(ValueError, match="Unknown language style"):
column_for_style("surprise")
def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot_dataset_factory):
root = tmp_path / "language_dataset"
dataset = empty_lerobot_dataset_factory(
root=root,
features={"state": {"dtype": "float32", "shape": (2,), "names": None}},
use_videos=False,
)
dataset.add_frame({"state": np.array([0.0, 1.0], dtype=np.float32), "task": "tidy"})
dataset.add_frame({"state": np.array([1.0, 2.0], dtype=np.float32), "task": "tidy"})
dataset.save_episode()
dataset.finalize()
persistent = [
{
"role": "assistant",
"content": "reach for the cup",
"style": "subtask",
"timestamp": 0.0,
"camera": None,
"tool_calls": None,
}
]
event = {
"role": "user",
"content": "what is visible?",
"style": "vqa",
"camera": "observation.images.top",
"tool_calls": None,
}
data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
df = pd.read_parquet(data_path)
df[LANGUAGE_PERSISTENT] = [persistent, persistent]
df[LANGUAGE_EVENTS] = [[event], []]
df.to_parquet(data_path)
info = dataset.meta.info
info["features"].update(language_feature_info())
write_info(info, root)
reloaded = LeRobotDataset(repo_id=dataset.repo_id, root=root)
first = reloaded[0]
second = reloaded[1]
assert first[LANGUAGE_PERSISTENT] == persistent
assert first[LANGUAGE_EVENTS] == [event]
assert second[LANGUAGE_PERSISTENT] == persistent
assert second[LANGUAGE_EVENTS] == []

View File

@@ -0,0 +1,378 @@
#!/usr/bin/env python
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
from lerobot.datasets.language_render import ( # noqa: E402
active_at,
emitted_at,
nth_next,
nth_prev,
render_sample,
)
def persistent_row(role, content, style, timestamp, tool_calls=None, camera=None):
return {
"role": role,
"content": content,
"style": style,
"timestamp": timestamp,
"camera": camera,
"tool_calls": tool_calls,
}
def event_row(role, content, style, tool_calls=None, camera=None):
return {
"role": role,
"content": content,
"style": style,
"camera": camera,
"tool_calls": tool_calls,
}
PERSISTENT = [
persistent_row("assistant", "plan 0", "plan", 0.0),
persistent_row("assistant", "memory 0", "memory", 0.0),
persistent_row("assistant", "subtask 0", "subtask", 0.0),
persistent_row("assistant", "memory 1", "memory", 1.0),
persistent_row("assistant", "subtask 1", "subtask", 1.0),
]
EVENTS_AT_1 = [
event_row("user", "what is visible?", "vqa", camera="observation.images.top"),
event_row("assistant", '{"count": 2}', "vqa", camera="observation.images.top"),
]
EVENTS_AT_2 = [
event_row("user", "skip wiping", "interjection"),
event_row(
"assistant",
None,
None,
[{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}],
),
]
# Same emission tick, two cameras: triggers per-camera disambiguation in
# resolvers, mirroring how Module 3 of the annotation pipeline writes one
# (vqa, user) + (vqa, assistant) pair per camera.
EVENTS_AT_3_TWO_CAMERAS = [
event_row("user", "how many cups (top)?", "vqa", camera="observation.images.top"),
event_row("assistant", '{"count": 3}', "vqa", camera="observation.images.top"),
event_row("user", "how many cups (wrist)?", "vqa", camera="observation.images.wrist"),
event_row("assistant", '{"count": 1}', "vqa", camera="observation.images.wrist"),
]
def test_resolver_temporal_semantics():
assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0"
assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1"
assert emitted_at(0.5, persistent=PERSISTENT, events=[], style="vqa", role="assistant") is None
assert (
emitted_at(1.0, persistent=PERSISTENT, events=EVENTS_AT_1, style="vqa", role="assistant")["content"]
== '{"count": 2}'
)
def test_persistent_relative_resolvers_reject_event_styles():
with pytest.raises(ValueError, match="event-only"):
active_at(1.0, persistent=PERSISTENT, style="vqa")
with pytest.raises(ValueError, match="event-only"):
nth_prev(1.0, persistent=PERSISTENT, style="interjection")
def test_nth_prev_and_next():
assert nth_prev(1.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 0"
assert nth_next(0.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 1"
def test_substitution_if_present_multimodal_and_tool_calls():
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="user",
content=[
{"type": "image", "feature": "observation.images.top"},
{"type": "text", "text": "${task}: ${interjection}"},
],
stream="high_level",
if_present="interjection",
),
MessageTurn(
role="assistant",
content="${plan}",
stream="high_level",
target=True,
tool_calls_from="speech",
),
],
bindings={"plan": "active_at(t, style=plan)"},
)
rendered = render_sample(
recipe=recipe,
persistent=PERSISTENT,
events=EVENTS_AT_2,
t=2.0,
sample_idx=0,
task="clean kitchen",
)
assert rendered["messages"][0]["content"][1]["text"] == "clean kitchen: skip wiping"
assert rendered["messages"][1]["content"] == "plan 0"
assert rendered["messages"][1]["tool_calls"][0]["function"]["name"] == "say"
assert rendered["message_streams"] == ["high_level", "high_level"]
assert rendered["target_message_indices"] == [1]
def test_exact_event_miss_returns_none_when_target_skips():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
]
)
assert (
render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=0) is None
)
def test_deterministic_blend_sampling():
recipe = TrainingRecipe(
blend={
"a": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="a", stream="high_level", target=True),
],
),
"b": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="b", stream="high_level", target=True),
],
),
}
)
first = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
)
second = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
)
assert first == second
def test_emitted_at_filters_vqa_by_camera():
top = emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
camera="observation.images.top",
)
wrist = emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
camera="observation.images.wrist",
)
assert top["content"] == '{"count": 3}'
assert wrist["content"] == '{"count": 1}'
def test_emitted_at_raises_on_ambiguous_per_camera_vqa():
with pytest.raises(ValueError, match="Ambiguous resolver"):
emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
)
def _vqa_subrecipe(camera: str) -> TrainingRecipe:
return TrainingRecipe(
weight=1.0,
bindings={
"vqa_query": f"emitted_at(t, style=vqa, role=user, camera={camera})",
"vqa": f"emitted_at(t, style=vqa, role=assistant, camera={camera})",
},
messages=[
MessageTurn(
role="user",
content=[{"type": "image", "feature": camera}, {"type": "text", "text": "${vqa_query}"}],
stream="high_level",
if_present="vqa_query",
),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
],
)
@pytest.mark.parametrize(
("camera", "expected_query", "expected_answer"),
[
("observation.images.top", "how many cups (top)?", '{"count": 3}'),
("observation.images.wrist", "how many cups (wrist)?", '{"count": 1}'),
],
)
def test_per_camera_blend_renders_both_views(camera, expected_query, expected_answer):
rendered = render_sample(
recipe=_vqa_subrecipe(camera),
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
t=3.0,
sample_idx=0,
)
assert rendered["messages"][0]["content"][0]["feature"] == camera
assert rendered["messages"][0]["content"][1]["text"] == expected_query
assert rendered["messages"][1]["content"] == expected_answer
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
rephrasings = [
persistent_row("user", "tidy the kitchen", "task_aug", 0.0),
persistent_row("user", "please clean up the kitchen", "task_aug", 0.0),
persistent_row("user", "kitchen needs tidying", "task_aug", 0.0),
persistent_row("user", "make the kitchen clean", "task_aug", 0.0),
]
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
# No explicit task override → resolver consults persistent rows.
seen: set[str] = set()
for sample_idx in range(64):
rendered = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=sample_idx,
dataset_ctx={"task": "canonical kitchen task"},
)
seen.add(rendered["messages"][0]["content"])
# Every rephrasing should be reachable across enough samples.
assert seen == {r["content"] for r in rephrasings}
# Same sample_idx → same pick (determinism).
a = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=42,
dataset_ctx={"task": "canonical"},
)
b = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=42,
dataset_ctx={"task": "canonical"},
)
assert a["messages"][0]["content"] == b["messages"][0]["content"]
def test_resolve_task_falls_back_to_canonical_without_rephrasings():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
rendered = render_sample(
recipe=recipe,
persistent=PERSISTENT, # no task_aug rows
events=[],
t=0.0,
sample_idx=0,
dataset_ctx={"task": "clean the kitchen"},
)
assert rendered["messages"][0]["content"] == "clean the kitchen"
def test_resolve_task_explicit_override_beats_rephrasings():
rephrasings = [
persistent_row("user", "rephrased one", "task_aug", 0.0),
persistent_row("user", "rephrased two", "task_aug", 0.0),
]
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
rendered = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=0,
task="explicit override wins",
dataset_ctx={"task": "canonical"},
)
assert rendered["messages"][0]["content"] == "explicit override wins"
def test_low_level_branch_renders_active_subtask():
low_level = TrainingRecipe(
blend={
"low": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(
role="user",
content="${task}\nPlan: ${plan}\nMemory: ${memory}",
stream="high_level",
),
MessageTurn(
role="assistant",
content="${subtask}",
stream="low_level",
target=True,
),
],
)
}
)
rendered = render_sample(
recipe=low_level,
persistent=PERSISTENT,
events=[],
t=0.5,
sample_idx=0,
task="clean kitchen",
)
assert rendered["messages"][-1] == {"role": "assistant", "content": "subtask 0"}
assert rendered["message_streams"][-1] == "low_level"
assert rendered["target_message_indices"] == [1]

View File

@@ -1,193 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for subtask functionality in LeRobotDataset.
These tests verify that:
- Subtask information is correctly loaded from datasets that have subtask data
- The __getitem__ method correctly adds subtask strings to returned items
- Subtask handling gracefully handles missing data
"""
import pytest
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import pandas as pd # noqa: E402
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
class TestSubtaskDataset:
"""Tests for subtask handling in LeRobotDataset."""
@pytest.fixture
def subtask_dataset(self):
"""Load the test subtask dataset from the hub."""
# Use lerobot/pusht-subtask dataset with episode 1
return LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
def test_subtask_dataset_loads(self, subtask_dataset):
"""Test that the subtask dataset loads successfully."""
assert subtask_dataset is not None
assert len(subtask_dataset) > 0
def test_subtask_metadata_loaded(self, subtask_dataset):
"""Test that subtask metadata is loaded when present in dataset."""
# The dataset should have subtasks metadata loaded
assert subtask_dataset.meta.subtasks is not None
assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame)
def test_subtask_index_in_features(self, subtask_dataset):
"""Test that subtask_index is a feature when dataset has subtasks."""
assert "subtask_index" in subtask_dataset.features
def test_getitem_returns_subtask_string(self, subtask_dataset):
"""Test that __getitem__ correctly adds subtask string to returned item."""
item = subtask_dataset[0]
# Subtask should be present in the returned item
assert "subtask" in item
assert isinstance(item["subtask"], str)
assert len(item["subtask"]) > 0 # Should not be empty
def test_getitem_has_subtask_index(self, subtask_dataset):
"""Test that __getitem__ includes subtask_index."""
item = subtask_dataset[0]
assert "subtask_index" in item
assert isinstance(item["subtask_index"], torch.Tensor)
def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset):
"""Test that subtask_index correctly maps to a subtask in metadata."""
item = subtask_dataset[0]
subtask_idx = item["subtask_index"].item()
subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name
assert item["subtask"] == subtask_from_metadata
def test_all_items_have_subtask(self, subtask_dataset):
"""Test that all items in the dataset have subtask information."""
for i in range(min(len(subtask_dataset), 5)): # Check first 5 items
item = subtask_dataset[i]
assert "subtask" in item
assert isinstance(item["subtask"], str)
def test_task_and_subtask_coexist(self, subtask_dataset):
"""Test that both task and subtask are present in returned items."""
item = subtask_dataset[0]
# Both task and subtask should be present
assert "task" in item
assert "subtask" in item
assert isinstance(item["task"], str)
assert isinstance(item["subtask"], str)
class TestSubtaskDatasetMissing:
"""Tests for graceful handling when subtask data is missing."""
@pytest.fixture
def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory):
"""Create a dataset without subtask information."""
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features)
# Add some frames and save
for _ in range(5):
dataset.add_frame({"state": torch.randn(2), "task": "Test task"})
dataset.save_episode()
dataset.finalize()
# Reload the dataset
return LeRobotDataset(dataset.repo_id, root=dataset.root)
def test_no_subtask_in_features(self, dataset_without_subtasks):
"""Test that subtask_index is not in features when not provided."""
assert "subtask_index" not in dataset_without_subtasks.features
def test_getitem_without_subtask(self, dataset_without_subtasks):
"""Test that __getitem__ works when subtask is not present."""
item = dataset_without_subtasks[0]
# Item should still be retrievable
assert item is not None
assert "state" in item
assert "task" in item
# Subtask should NOT be present
assert "subtask" not in item
def test_subtasks_metadata_is_none(self, dataset_without_subtasks):
"""Test that subtasks metadata is None when not present."""
assert dataset_without_subtasks.meta.subtasks is None
class TestSubtaskEdgeCases:
"""Edge case tests for subtask handling."""
def test_subtask_with_multiple_episodes(self):
"""Test subtask handling with multiple episodes if available."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
# Check first and last items have valid subtasks
first_item = dataset[0]
last_item = dataset[len(dataset) - 1]
assert "subtask" in first_item
assert "subtask" in last_item
assert isinstance(first_item["subtask"], str)
assert isinstance(last_item["subtask"], str)
def test_subtask_index_consistency(self):
"""Test that same subtask_index returns same subtask string."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
if len(dataset) < 2:
pytest.skip("Dataset too small for this test")
# Collect subtask_index to subtask mappings
subtask_map = {}
for i in range(min(len(dataset), 10)):
item = dataset[i]
idx = item["subtask_index"].item()
subtask = item["subtask"]
if idx in subtask_map:
# Same index should always return same subtask
assert subtask_map[idx] == subtask, (
f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'"
)
else:
subtask_map[idx] = subtask

View File

@@ -1,186 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Smoke tests for EO1's public LeRobot policy interface."""
from __future__ import annotations
from types import SimpleNamespace
import pytest
import torch
from torch import nn
pytest.importorskip("transformers")
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.eo1.modeling_eo1 import EO1Policy
from lerobot.utils.constants import ACTION, OBS_STATE
HIDDEN_SIZE = 8
STATE_DIM = 4
ACTION_DIM = 3
CHUNK_SIZE = 3
N_ACTION_STEPS = 2
MAX_ACTION_DIM = 6
STATE_TOKEN_ID = 5
ACTION_TOKEN_ID = 6
class DummyVLMBackbone(nn.Module):
def __init__(self, hidden_size: int, vocab_size: int = 64):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_size)
self.config = SimpleNamespace(text_config=SimpleNamespace(hidden_size=hidden_size))
@property
def model(self):
return self
def get_input_embeddings(self):
return self.embedding
def get_rope_index(
self,
input_ids: torch.Tensor,
image_grid_thw: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
mm_token_type_ids: torch.Tensor | None = None,
):
batch_size, seq_len = input_ids.shape
if attention_mask is None:
text_positions = torch.arange(seq_len, device=input_ids.device).expand(batch_size, -1)
else:
text_positions = attention_mask.long().cumsum(-1) - 1
text_positions = text_positions.masked_fill(attention_mask == 0, 0)
position_ids = text_positions.view(1, batch_size, seq_len).expand(3, batch_size, seq_len)
rope_deltas = torch.zeros(batch_size, 1, dtype=torch.long, device=input_ids.device)
return position_ids, rope_deltas
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
return gradient_checkpointing_kwargs
def gradient_checkpointing_disable(self):
return None
def forward(
self,
*,
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
):
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
return SimpleNamespace(
last_hidden_state=inputs_embeds,
past_key_values=SimpleNamespace(crop=lambda prefix_len: None),
)
def make_eo1_config():
from lerobot.policies.eo1.configuration_eo1 import EO1Config
return EO1Config(
device="cpu",
dtype="float32",
vlm_base="dummy-qwen",
vlm_config={},
chunk_size=CHUNK_SIZE,
n_action_steps=N_ACTION_STEPS,
max_state_dim=STATE_DIM,
max_action_dim=MAX_ACTION_DIM,
num_denoise_steps=2,
input_features={
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
},
output_features={
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
},
)
def make_policy_batch(include_action: bool) -> dict[str, torch.Tensor | int]:
batch_size = 1
seq_len = CHUNK_SIZE + 4
input_ids = torch.tensor(
[[11, STATE_TOKEN_ID, 12, ACTION_TOKEN_ID, ACTION_TOKEN_ID, ACTION_TOKEN_ID, 13]],
dtype=torch.long,
)
assert input_ids.shape == (batch_size, seq_len)
batch: dict[str, torch.Tensor | int] = {
OBS_STATE: torch.randn(batch_size, STATE_DIM, dtype=torch.float32),
"input_ids": input_ids,
"attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
"pixel_values": torch.zeros(batch_size, 3, 4, 4, dtype=torch.float32),
"image_grid_thw": torch.tensor([[1, 2, 2]], dtype=torch.long),
"mm_token_type_ids": torch.zeros(batch_size, seq_len, dtype=torch.int32),
"state_token_id": STATE_TOKEN_ID,
"action_token_id": ACTION_TOKEN_ID,
}
if include_action:
batch[ACTION] = torch.randn(batch_size, CHUNK_SIZE, ACTION_DIM, dtype=torch.float32)
return batch
def test_lerobot_eo1_forward_pass(monkeypatch):
monkeypatch.setattr(
"lerobot.policies.eo1.modeling_eo1.Qwen2_5_VLForConditionalGeneration.from_pretrained",
lambda *args, **kwargs: DummyVLMBackbone(HIDDEN_SIZE),
)
policy = EO1Policy(make_eo1_config())
loss, metrics = policy.forward(make_policy_batch(include_action=True))
assert loss.ndim == 0
assert torch.isfinite(loss)
assert metrics["loss"] == pytest.approx(loss.item())
def test_lerobot_eo1_inference(monkeypatch):
monkeypatch.setattr(
"lerobot.policies.eo1.modeling_eo1.Qwen2_5_VLForConditionalGeneration.from_pretrained",
lambda *args, **kwargs: DummyVLMBackbone(HIDDEN_SIZE),
)
policy = EO1Policy(make_eo1_config())
sample_calls = {"count": 0}
fixed_chunk = torch.tensor(
[
[
[0.1, 0.2, 0.3, 9.0, 9.0, 9.0],
[1.1, 1.2, 1.3, 9.0, 9.0, 9.0],
[2.1, 2.2, 2.3, 9.0, 9.0, 9.0],
]
],
dtype=torch.float32,
)
def fake_sample_actions(**kwargs):
sample_calls["count"] += 1
return fixed_chunk
monkeypatch.setattr(policy.model, "sample_actions", fake_sample_actions)
batch = make_policy_batch(include_action=False)
action_0 = policy.select_action(batch)
action_1 = policy.select_action(batch)
torch.testing.assert_close(action_0, fixed_chunk[:, 0, :ACTION_DIM])
torch.testing.assert_close(action_1, fixed_chunk[:, 1, :ACTION_DIM])
assert sample_calls["count"] == 1

View File

@@ -1,242 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from torch import nn
import lerobot.policies.evo1.modeling_evo1 as modeling_evo1
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
from lerobot.policies.factory import get_policy_class, make_policy_config
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
STATE_DIM = 4
ACTION_DIM = 3
MAX_STATE_DIM = 6
MAX_ACTION_DIM = 5
CHUNK_SIZE = 2
EMBED_DIM = 8
class DummyEVO1(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.action_head = nn.Linear(1, 1)
self.get_vl_embeddings_calls = 0
def set_finetune_flags(self):
return None
def get_vl_embeddings(self, images, image_mask, prompt=None, return_cls_only=False):
self.get_vl_embeddings_calls += 1
return torch.ones(len(images), 4, EMBED_DIM)
def forward(
self,
fused_tokens,
state=None,
actions_gt=None,
action_mask=None,
embodiment_ids=None,
):
batch_size = fused_tokens.shape[0]
if actions_gt is None:
return torch.ones(batch_size, CHUNK_SIZE * MAX_ACTION_DIM)
pred_velocity = torch.zeros(batch_size, CHUNK_SIZE * MAX_ACTION_DIM)
noise = torch.zeros_like(actions_gt)
return pred_velocity, noise
def make_config(training_stage="stage1", **kwargs):
config_kwargs = {
"device": "cpu",
"vlm_model_name": "dummy-internvl3",
"training_stage": training_stage,
"chunk_size": CHUNK_SIZE,
"n_action_steps": 1,
"max_state_dim": MAX_STATE_DIM,
"max_action_dim": MAX_ACTION_DIM,
"max_views": 2,
"embed_dim": EMBED_DIM,
"hidden_dim": 16,
"state_hidden_dim": 16,
"num_heads": 2,
"num_layers": 1,
"num_inference_timesteps": 2,
"input_features": {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
},
"output_features": {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
},
}
config_kwargs.update(kwargs)
return Evo1Config(**config_kwargs)
def make_batch(include_action=True):
batch = {
"task": ["pick the block", "place the block"],
OBS_STATE: torch.randn(2, STATE_DIM),
f"{OBS_IMAGES}.front": torch.rand(2, 3, 16, 16),
}
if include_action:
batch[ACTION] = torch.randn(2, CHUNK_SIZE, ACTION_DIM)
return batch
def test_evo1_factory_registration():
cfg = make_policy_config(
"evo1",
device="cpu",
vlm_model_name="dummy-internvl3",
input_features={
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
)
assert isinstance(cfg, Evo1Config)
assert get_policy_class("evo1") is modeling_evo1.EVO1Policy
def test_evo1_stage_defaults_and_consistency():
stage1 = make_config(training_stage="stage1")
assert (stage1.finetune_vlm, stage1.finetune_language_model, stage1.finetune_vision_model) == (
False,
False,
False,
)
assert stage1.finetune_action_head is True
stage2 = make_config(training_stage="stage2")
assert (stage2.finetune_vlm, stage2.finetune_language_model, stage2.finetune_vision_model) == (
True,
True,
True,
)
assert stage2.finetune_action_head is True
explicit_off = make_config(
training_stage="stage2",
finetune_vlm=False,
finetune_language_model=False,
finetune_vision_model=False,
finetune_action_head=False,
)
assert (
explicit_off.finetune_vlm,
explicit_off.finetune_language_model,
explicit_off.finetune_vision_model,
) == (
False,
False,
False,
)
assert explicit_off.finetune_action_head is False
try:
make_config(training_stage="stage2", finetune_vlm=True, finetune_language_model=False)
except ValueError as exc:
assert "Inconsistent EVO1 finetune config" in str(exc)
else:
raise AssertionError("Expected inconsistent finetune config to raise ValueError")
def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch):
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
policy = modeling_evo1.EVO1Policy(make_config())
loss, metrics = policy.forward(make_batch(include_action=True))
assert loss.ndim == 0
assert torch.isfinite(loss)
assert metrics["active_action_dims"] == ACTION_DIM * CHUNK_SIZE
assert policy.model.get_vl_embeddings_calls == 1
action_chunk = policy.predict_action_chunk(make_batch(include_action=False))
assert action_chunk.shape == (2, CHUNK_SIZE, ACTION_DIM)
policy.reset()
selected = policy.select_action(make_batch(include_action=False))
assert selected.shape == (2, ACTION_DIM)
def test_collect_image_batches_handles_unbatched_chw(monkeypatch):
# Regression for an issue where batch_size was read from shape[0] before normalizing
# per-camera tensor dims, so an unbatched (C, H, W) input was treated as batch_size=C.
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
policy = modeling_evo1.EVO1Policy(make_config())
batch = {
OBS_STATE: torch.randn(1, STATE_DIM),
f"{OBS_IMAGES}.front": torch.rand(3, 16, 16),
}
image_batches, image_masks = policy._collect_image_batches(batch)
assert len(image_batches) == 1
assert len(image_batches[0]) == policy.config.max_views
assert image_masks.tolist() == [[True, False]]
def test_evo1_action_mask_accepts_chunk_size_one(monkeypatch):
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
config = make_config(chunk_size=1, n_action_steps=1)
policy = modeling_evo1.EVO1Policy(config)
batch = make_batch(include_action=True)
batch[ACTION] = torch.randn(2, ACTION_DIM)
batch["action_mask"] = torch.ones(2, ACTION_DIM, dtype=torch.bool)
actions, action_mask = policy._prepare_actions(batch)
assert actions.shape == (2, 1, MAX_ACTION_DIM)
assert action_mask.shape == (2, 1, MAX_ACTION_DIM)
assert action_mask[:, :, :ACTION_DIM].all()
assert not action_mask[:, :, ACTION_DIM:].any()
def test_flowmatching_dict_config_enables_state_encoder_for_horizon_one():
head = FlowmatchingActionHead(
config={
"embed_dim": EMBED_DIM,
"hidden_dim": 16,
"action_dim": ACTION_DIM,
"horizon": 1,
"per_action_dim": ACTION_DIM,
"num_heads": 2,
"num_layers": 1,
"num_inference_timesteps": 2,
"state_dim": STATE_DIM,
"state_hidden_dim": 16,
"num_categories": 1,
}
)
assert head.state_encoder is not None
pred_velocity, noise = head(
torch.randn(2, 4, EMBED_DIM),
state=torch.randn(2, STATE_DIM),
actions_gt=torch.randn(2, 1, ACTION_DIM),
action_mask=torch.ones(2, 1, ACTION_DIM, dtype=torch.bool),
)
assert pred_velocity.shape == (2, ACTION_DIM)
assert noise.shape == (2, 1, ACTION_DIM)

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
from lerobot.processor.converters import create_transition # noqa: E402
from lerobot.processor.render_messages_processor import RenderMessagesStep # noqa: E402
from lerobot.types import TransitionKey # noqa: E402
def test_render_messages_step_noops_without_language_columns():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
]
)
transition = create_transition(complementary_data={"task": "do it"})
assert RenderMessagesStep(recipe)(transition) == transition
def test_render_messages_step_renders_and_drops_raw_language():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
]
)
transition = create_transition(
complementary_data={
"task": "do it",
"timestamp": torch.tensor(0.0),
"index": torch.tensor(7),
"language_persistent": [
{
"role": "assistant",
"content": "reach carefully",
"style": "subtask",
"timestamp": 0.0,
"camera": None,
"tool_calls": None,
}
],
"language_events": [],
}
)
out = RenderMessagesStep(recipe)(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert "language_persistent" not in data
assert "language_events" not in data
assert data["messages"][-1]["content"] == "reach carefully"
assert data["message_streams"] == ["high_level", "low_level"]
assert data["target_message_indices"] == [1]

View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from lerobot.utils.collate import lerobot_collate_fn # noqa: E402
def test_lerobot_collate_preserves_messages_and_drops_raw_language():
batch = [
{
"index": torch.tensor(0),
"messages": [{"role": "assistant", "content": "a"}],
"message_streams": ["low_level"],
"target_message_indices": [0],
"language_persistent": [{"content": "raw"}],
"language_events": [],
},
{
"index": torch.tensor(1),
"messages": [{"role": "assistant", "content": "b"}],
"message_streams": ["low_level"],
"target_message_indices": [0],
"language_persistent": [{"content": "raw"}],
"language_events": [],
},
]
out = lerobot_collate_fn(batch)
assert out["index"].tolist() == [0, 1]
assert out["messages"][0][0]["content"] == "a"
assert out["messages"][1][0]["content"] == "b"
assert out["message_streams"] == [["low_level"], ["low_level"]]
assert out["target_message_indices"] == [[0], [0]]
assert "language_persistent" not in out
assert "language_events" not in out
def test_lerobot_collate_passes_through_standard_batch():
"""On a non-language batch, the collate must match ``default_collate``.
Guards against silent regressions: ``lerobot_train.py`` only opts into
``lerobot_collate_fn`` when the dataset declares language columns, but
if a future change ever wires it in unconditionally we want the
behavior to remain a transparent pass-through for ordinary tensor
batches.
"""
from torch.utils.data._utils.collate import default_collate
batch = [
{
"observation.image": torch.zeros(3, 4, 4),
"action": torch.tensor([0.0, 1.0]),
"index": torch.tensor(0),
},
{
"observation.image": torch.ones(3, 4, 4),
"action": torch.tensor([2.0, 3.0]),
"index": torch.tensor(1),
},
]
custom = lerobot_collate_fn(batch)
expected = default_collate(batch)
assert custom.keys() == expected.keys()
for key in expected:
assert torch.equal(custom[key], expected[key]), f"key={key} diverged"
def test_lerobot_collate_drops_none_samples():
"""Recipes that yielded no target message return ``None`` — those samples
must be filtered out, and an entirely-``None`` batch must collapse to ``None``.
"""
batch = [None, {"index": torch.tensor(0)}, None]
out = lerobot_collate_fn(batch)
assert out is not None
assert out["index"].tolist() == [0]
assert lerobot_collate_fn([None, None]) is None

1167
uv.lock generated

File diff suppressed because it is too large Load Diff