Compare commits

...

21 Commits

Author SHA1 Message Date
Pepijn
8455efc474 feat(umi): simplify to derive_state_from_action and cam0-only
- Remove fix_dataset.py (user fixes dataset at source)
- evaluate.py: replace observation.pose/joints with observation.state
  (8D, derived from action during training, from FK at inference)
- evaluate.py: remove cam1 — training uses only cam0
- docs: rewrite workflow around derive_state_from_action=true,
  updated recompute-stats and training commands with
  relative_exclude_joints for gripper dims

Made-with: Cursor
2026-04-02 15:02:20 +02:00
Pepijn
e627d6442e feat(umi): add EE replay viewer, URDF meshes, and evaluate script updates
- Add replay.py script and replay_viewer.html for browser-based EE
  trajectory visualization from glannuzel/grabette-dataset
- Add viewer.html for interactive URDF inspection
- Move OpenArm URDF and meshes into openarm_follower/urdf/
- Add virtual EE target frame (openarm_right_ee_target) at 7cm from link7
- Adapt evaluate.py for single right-arm OpenArm with wrist camera
- Update docs with replay viewer usage
- Update openarm_follower config, driver, and kinematic processor

Made-with: Cursor
2026-04-02 14:25:24 +02:00
Pepijn
b08a62af89 feat(examples): adapt UMI pi0 evaluate script for OpenArm follower
Switch from SO100 to a single right OpenArm follower with one camera
(cam0 at 960x720). Strip dataset recording — just execute the policy.
Filter out .vel/.torque observation features for the EE pipeline.

Made-with: Cursor
2026-04-02 13:01:46 +02:00
Pepijn
d028978552 nit 2026-04-01 18:04:03 +02:00
Pepijn
58bd11caf3 refactor to use relative state 2026-04-01 17:23:58 +02:00
Pepijn
0fc855df13 fix 2026-04-01 15:29:59 +02:00
Pepijn
dfe16e8b84 fixes, do stats in seperate script (existing) 2026-04-01 13:59:44 +02:00
Pepijn
5ac3e568f1 add umi example 2026-04-01 13:48:06 +02:00
Pepijn
15934d8d08 feat(policies): add relative action support for pi0, pi0.5, and pi0_fast (#2970)
* Add option for pi family models to train with relative actions (relative to state)

* formatting

* add recomputation of stats and option to compute delta stats

* normalzie after delta conversion

* only recompute state for stats

* calulate chunk based stats

* sample 100k

* load from parquet

* sample 1m

* stats per chunck

* fix

* use quantiles

* stats for entire dataset

* fix

* max 1m frames

* compute before dist

* fix multi gpu processor bug

* Fix RTC with delta actions and OpenArms motor_type wiring

* feat: align pi0_fast delta actions with pi0/pi05 and add RTC integration tests

- Add delta_exclude_joints and action_feature_names to PI0FastConfig
- Move to_absolute_actions from modeling to processor pipeline for pi0_fast
- Add delta action detection and logging to eval_with_real_robot.py
- Add delta actions documentation to pi0 and pi05 READMEs
- Fix ruff lint issues in test_delta_actions.py
- Add test_rtc_delta_actions.py (24 tests) covering:
  - ActionQueue with delta vs absolute actions
  - RTC denoise step with delta leftovers
  - Full pipeline roundtrip (delta → RTC → absolute)
  - State rebasing approximation bounds
  - Non-delta policy compatibility
  - Multi-chunk consistency

* chore: clean up test comments, add OpenPI attribution, remove debug logging

- Replace decorative comment separators in test files with plain section headers
- Add attribution comments for 1e-6 epsilon in normalize_processor.py (from OpenPI)
- Remove debug logging blocks from lerobot_train.py

* refactor: extract compute_delta_action_stats into compute_stats.py

Move the ~70-line inline delta action stats block from lerobot_train.py
into a dedicated function in compute_stats.py, where all other stats
computation already lives. The training script now calls it in 6 lines.

* refactor: remove unused get_processed_left_over from ActionQueue

This method was never called outside of tests. Leftover actions for RTC
guidance are always retrieved via get_left_over() (delta/original space).

* revert: remove logging-only changes from eval_with_real_robot.py

The delta actions detection helper and log message added no functional
value — the script already handles delta policies correctly via the
processor pipeline.

* refactor: use ACTION/OBS_STATE constants instead of hardcoded strings

Replace hardcoded "action" and "observation.state" with ACTION and
OBS_STATE from utils.constants in compute_stats.py, dataset_tools.py,
and lerobot_train.py.

* style: remove stray blank lines in training loop

* refactor: move delta action stats to preprocessing step, remove on-the-fly computation

- Remove on-the-fly compute_delta_action_stats from lerobot_train.py
- Rewrite recompute_stats to delegate action stats to compute_delta_action_stats
  (chunk-based sampling matching what the model sees during training)
- Add chunk_size parameter to recompute_stats for delta action computation
- Add delta actions documentation to pi0.mdx and pi05.mdx

* feat: add recompute_stats CLI operation to lerobot-edit-dataset

* fix(tests): relax quantile normalization test tolerance for 1e-6 epsilon

* chore: remove agents_memory/pr_details.md from repo

* refactor: rename delta actions to relative actions throughout

What OpenPI calls "DeltaActions" is actually UMI's "relative trajectory"
representation: each action in the chunk is an offset from the current
state, not from the previous action. This avoids error accumulation.

Renamed across all source, tests, docs, and CLI:
- DeltaActionsProcessorStep → RelativeActionsProcessorStep
- to_delta_actions → to_relative_actions
- use_delta_actions → use_relative_actions
- delta_exclude_joints → relative_exclude_joints
- compute_delta_action_stats → compute_relative_action_stats
- delta_action_processor.py → relative_action_processor.py
- test_delta_actions.py → test_relative_actions.py

Kept as-is: AbsoluteActionsProcessorStep (converts TO absolute),
registry ID "delta_actions_processor" (backward compat), and unrelated
delta references (IK pipeline, Robosuite, RA-BC metrics, gym envs).

* docs: add Action Representations guide

Dedicated page explaining absolute, relative, and delta actions with
numerical examples, joint vs EE space, and how to use kinematics
pipelines and the relative action processor. References UMI paper
(Chi et al., 2024) for the terminology.

* docs: remove redundant OpenPI naming note from action representations

* docs: remove opinionated OpenPI reference from delta actions section

* docs: replace ASCII diagram with UMI paper figure

* docs: remove OpenPI reference from action representations

* docs: use HF-hosted image instead of local asset

* docs: clarify figure attribution

* revert: restore original normalization epsilon behavior

The 1e-6 unconditional epsilon change perturbed all normalized values,
breaking backward compatibility tests. The original approach (1e-8 eps
for MEAN_STD, conditional torch.where for QUANTILES) already handles
division by zero correctly without affecting non-degenerate cases.

* fix: restore delta_action_processor.py used by phone/RL teleop

The rename commit incorrectly deleted delta_action_processor.py and
duplicated its classes into relative_action_processor.py. Restore the
original file and import from it instead.

* fix(processor): address PR #2970 review comments

- Remove shebang from relative_action_processor.py (library module, not script)
- Add device alignment in to_relative_actions/to_absolute_actions so _last_state
  on CPU doesn't cause cross-device errors when actions are on CUDA
- Rename delta_step → relative_step in AbsoluteActionsProcessorStep for naming
  consistency; update factory.py, all processor files, and tests
- Expand _reconnect_relative_absolute_steps docstring to explain why post-hoc
  rewiring is needed after deserialization
- Fix off-by-one in compute_stats.py: sample_upper_bound = total_frames - chunk_size + 1
  so last valid start index is included and total_frames == chunk_size is not rejected
- Remove redundant NOTE comment in processor_pi05.py (duplicated two lines below)
- Fix pi0_fast processor ordering: move relative_step before NormalizerProcessorStep
  so normalizer sees delta actions (matching pi0/pi05); flip postprocessor to
  unnormalize → absolute accordingly. Relative stats are now required for all pi models
- Revert use_relative_joint_actions_aloha → use_delta_joint_actions_aloha in
  configuration_smolvla.py (preserve existing public API)
- Update action_representations.mdx: add missing joint to 6-DOF example, fix
  'based on a figure', clarify pi family ordering, add RTC compatibility section

* update rtc link

* feat: compute relative action stats over full dataset with optional parallelism

Remove the 100k sample cap from compute_relative_action_stats and process
all valid chunks. Vectorize with numpy (pre-load actions/states, fancy
indexing + broadcasting) for a large speedup over the per-index HF dataset
loop. Add num_workers param for thread-based parallelism (numpy releases
the GIL). Update docs to show --push_to_hub for recompute_stats.

* style: apply ruff formatting to compute_stats.py

* testing on real robot

* style: fix ruff format and remove redundant .keys() calls
2026-04-01 12:59:12 +02:00
Jai Kumaar Ratadia
9300352876 Fix SO-101 assembly instruction order to match videos (#3242)
* Fix SO-101 assembly instruction order to match videos

Motor horn installation steps were listed after placing motors
into the housing, but the assembly videos show installing horns
first. Reordered steps to match the videos, which is also the
easier approach since horns are harder to attach once the motor
is seated. Added missing detail that bottom horns don't require
screws.

* Update docs/source/so101.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Jai Kumaar Ratadia <jaikumaarratadia@gmail.com>

---------

Signed-off-by: Jai Kumaar Ratadia <jaikumaarratadia@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-03-31 12:16:34 +02:00
Steven Palma
720cf8e3a0 Revert "fix(deps): breaking change from transformers 5.4.0" (#3249)
* Revert "fix(deps): breaking change from transformers 5.4.0 (#3231)"

This reverts commit 07502868e5.

* chore(dependecies): pin transformers to 5.3.0 temporarily
2026-03-30 19:11:41 +02:00
Steven Palma
5d4fdf5088 feat(scripts): add transformers version (#3248)
* feat(scripts): add transformers and torch version

* chore(scripts): remove pytorch
2026-03-30 16:33:17 +02:00
四七
3b185f7f9d fix(datasets): remove unreachable timestamp branch in add_frame (#3163)
* fix(datasets): remove unreachable timestamp branch in add_frame and document caller contract

- Remove dead code: frame.pop("timestamp") branch in add_frame() could never
  execute because validate_frame() raises ValueError for any DEFAULT_FEATURES
  key (including timestamp) before we reach that line.
- Expand add_frame() docstring: explicitly document that timestamp and
  frame_index must NOT be passed by the caller.
- Add explanatory comment in validate_frame(): clarifies why DEFAULT_FEATURES
  are excluded from expected_features, preventing future re-introduction of
  the dead branch.

The dead branch originated in #1200, which fixed a shape-(1,) mismatch for a
code path that was subsequently made unreachable by a refactor of validate_frame.

* chore(datasets): narrow PR scope

* fix(datasets): move add_frame timestamp cleanup to dataset_writer
2026-03-28 11:37:57 +01:00
Bryson Jones
2e069b1c47 Feature/add multitask diffusion transformer policy implementation (#2545)
* Add multitask diffusion transformer policy

Add multitask diffusion transformer policy

* expand the observation encoder to support differnt size encoders for vision and text

* add RoPE attention module as this is shown to help training dynamics and generation quality for DiTs

* update readme and citations for multitask dit policy

* remove dino vision encoder and simplify text and vision encoders by removing inheritance structure

* adjust factory comment

* update docstring for multitask dit policy processor file

* simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives

* add references to the modeling file comments

* merge all modules files into the main modeling file

* add torch.no_grad decorators

* split up select action return statement

* remove redundant asserts

* add tutorial to training with multi_task_dit

* fix bugs when testing on hardware

* remove environment state conditioning

* update typo in test instruction comment

* add processor tests to multitask dit tests

* move policy to top of file

* use constants for indexing into batches and remove env state references

* remove the base classes since we don't need to be able to extend

* fix nit formatting in generate actions fcn

* reformat and clean up tutorial for multitask dit policy

* add more descriptions and depth to multitask dit tutorial

* note origins of each training objective

* rename config param for multiple vision encoders

* refactor code to perform task tokenization in the processor instead of in the modeling code for multitask dit

* add multitask dit to toc for docs

* add conditional transformers import to match all other policies that use transformers lib

* add test handling for multitask dit when transformers isnt available

* skip tests without transformers

* remove cropping of images smaller than the crop size

* add kwargs arg to multitask dit constructor

* add wallx dep conflict management for multitask dit policy

* use hyphens for cleanliness in pyproject.toml

* add conflict management to pyproject toml for pi conflict for mtdp as well

* update tests script to not use unnecessary uv sync call which resolves dependencies that do not need to run. This drastically reduces CI run time

* revert fast tests edits

* update docs and readme files, fixing some typos and adding multitask dit to readme

* chore(dependencies): upgrade transformers + hggingface-hub + peft + scipy

* chore(dependencies): bump pi0 family to transformers v5

* chore(dependencies): bump wall x to transformers v5

* chore(dependencies): bump gr00t to transformers v5

* chore(style): fix pre-commit

* fix(policy): xvla forced_bos_token missing

* test(rl): skip ci tests for resnet10

* Fix: full pi models support for transformer v5 (#2967)

* fix(pi): remove loss truncation

* fix(pi): remove state padding before tokenization

* fix(pi): fix image padding value

* fix from_pretrain

* add transformer v5 changes

* remove reference

* more fixes

* make it work

* add support for rest of pi family

* add pifast work

* more changes

* more changes

* more cleanup

* fix torch params

* dtype fix

* torch compile

* embed mismatch fix

* revert groot

* more nit fixes

* remove unused classes

* more fixes

* revert

* nit

* torch dtype warning fix

* but back dynamic renaming

* add tie embedding

---------

Co-authored-by: Yufei Sun <skieyfly@gmail.com>

* chore: fix XVLA in transformers v5 (#3006)

* test(policies): enable wall x CI testing

* style(test): pre-commit check

* style(test): pre-commit

---------

Signed-off-by: Bryson Jones <63133702+brysonjones@users.noreply.github.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Yufei Sun <skieyfly@gmail.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2026-03-28 00:41:26 +01:00
Steven Palma
4e45acca52 fix(dataset): use revision-safe Hub cache for downloaded datasets (#3233)
* refactor(dataset): enhance dataset root directory handling and introduce hub cache support

- Updated DatasetConfig and LeRobotDatasetMetadata to clarify root directory behavior and introduce a dedicated hub cache for downloads.
- Refactored LeRobotDataset and StreamingLeRobotDataset to utilize the new hub cache and improve directory management.
- Added tests to ensure correct behavior when using the hub cache and handling different revisions without a specified root directory.

* refactor(dataset): improve root directory handling in LeRobotDataset

- Updated LeRobotDataset to store the requested root path separately from the actual root path.
- Adjusted metadata loading to use the requested root, enhancing clarity and consistency in directory management.

* refactor(dataset): minor improvements for hub cache support

* chore(datasets): guard in resume + assertion test

---------

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: mickaelChen <mickael.chen.levinson@gmail.com>
2026-03-27 22:21:55 +01:00
Maxime Ellerbach
975d89b38d chore(docs): add more guidance to bring your own policies tutorial (#3230)
* chore(docs): add more guidance to bring your own policies tutorial

* removing normalization to avoid confusion with processors

* trailing whitespace

* Update docs/source/bring_your_own_policies.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* Update docs/source/bring_your_own_policies.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adding get optim params and predict_action chunk

* removing extra quotes

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
2026-03-27 21:25:37 +01:00
Maxime Ellerbach
07502868e5 fix(deps): breaking change from transformers 5.4.0 (#3231)
* fix(deps): breaking change from transformers 5.4.0

* Update src/lerobot/policies/xvla/modeling_florence2.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* Update src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* removing dataclass

* bumping transformers 5.4.0

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-27 21:25:12 +01:00
Reece O'Mahoney
aa9cc9bd43 fix(logging): suppress noisy httpx INFO logs (#3173)
Set httpx logger level to WARNING in init_logging to prevent
HTTP request traces from flooding the terminal during train and
eval scripts.

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-03-26 21:05:15 +01:00
Steven Palma
123495250b refactor(dataset): split LeRobotDataset into DatasetReader & DatasetWriter (+ API cleanup) (#3180)
* refactor(dataset): split reader and writer

* chore(dataset): remove proxys

* refactor(dataset): better reader & writer encapsulation

* refactor(datasets): clean API + reduce leaky implementations

* refactor(dataset): API cleaning for writer, reader and meta

* refactor(dataset): expose writer & reader + other minor improvements

* refactor(dataset): improve teardown routine

* refactor(dataset): add hf_dataset property at the facade level

* chore(dataset): add init for datasset module

* docs(dataset): add docstrings for public API of the dataset classes

* tests(dataset): add tests for new classes

* fix(dataset): remove circular dependecy
2026-03-26 19:09:25 +01:00
Jade Choghari
017ff73fbf chore(docs): add rename map and empty cam guide (#3065)
* add blog/guide

* add to tree

* chore(docs): rephrase rename_map docs for clarity and simplicity

---------

Co-authored-by: Steven Palma <steven.palma@huggingface.co>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-03-23 13:57:53 -07:00
Praedico
f90db58c15 docs(async): fix GitHub issues link (#3186) 2026-03-19 22:32:07 -07:00
109 changed files with 10570 additions and 1505 deletions

2
.gitignore vendored
View File

@@ -173,7 +173,5 @@ outputs/
# Dev folders
.cache/*
*.stl
*.urdf
*.xml
*.part

View File

@@ -100,11 +100,11 @@ lerobot-train \
--dataset.repo_id=lerobot/aloha_mobile_cabinet
```
| Category | Models |
| -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
| Category | Models |
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub

View File

@@ -19,6 +19,10 @@
title: Multi GPU training
- local: peft_training
title: Training with PEFT (e.g., LoRA)
- local: rename_map
title: Using Rename Map and Empty Cameras
- local: umi_pi0_relative_ee
title: UMI Data with pi0 Relative EE Actions
title: "Tutorials"
- sections:
- local: lerobot-dataset-v3
@@ -47,6 +51,8 @@
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
- local: multi_task_dit
title: Multitask DiT Policy
- local: walloss
title: WALL-OSS
title: "Policies"
@@ -83,6 +89,8 @@
title: Processors for Robots and Teleoperators
- local: env_processor
title: Environment Processors
- local: action_representations
title: Action Representations
title: "Robot Processors"
- sections:
- local: so101

View File

@@ -0,0 +1,238 @@
# Action Representations
This guide explains the different ways robot actions can be represented in LeRobot, how they relate to each other, and when to use each one.
## Joint Space vs End-Effector Space
Before discussing action representations, it helps to understand the two coordinate spaces actions can live in.
### Joint Space
Joint-space actions directly specify target positions for each motor. For a 6-DOF arm with a gripper, a joint-space action might look like:
```
action = [shoulder_pan: 45.0, shoulder_lift: -20.0, elbow: -30.0, wrist_pitch: 10.0, wrist_roll: 0.0, wrist_yaw: 5.0, gripper: 0.8]
```
Joint space is the default in LeRobot. It is simple, requires no kinematics model, and maps directly to motor commands. Most beginner setups (SO-100, Koch) use joint-space actions.
### End-Effector (EE) Space
End-effector-space actions specify the desired position and orientation of the robot's tool tip (gripper) in Cartesian coordinates:
```
action = [x: 0.25, y: -0.10, z: 0.15, wx: 0.0, wy: 0.0, wz: 0.1, gripper: 0.8]
```
EE space is more intuitive for tasks like pick-and-place because it directly describes where the gripper should go, but it requires a kinematics model (URDF) to convert between EE poses and joint angles.
### Converting Between Spaces
LeRobot provides processor steps for converting between joint and EE spaces using forward and inverse kinematics. These are built on top of `RobotKinematics`, which loads a URDF model of your robot.
```python
from lerobot.model.kinematics import RobotKinematics
from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
kinematics = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=["shoulder", "elbow", "wrist_pitch", "wrist_roll", "wrist_yaw"],
)
# Joints → EE (for observations: "where is my gripper?")
fk_step = ForwardKinematicsJointsToEE(kinematics=kinematics, motor_names=[...])
# EE → Joints (for actions: "move my gripper here")
ik_step = InverseKinematicsEEToJoints(kinematics=kinematics, motor_names=[...])
```
See [`examples/so100_to_so100_EE/`](https://github.com/huggingface/lerobot/tree/main/examples/so100_to_so100_EE) for a complete working example of recording, replaying, and evaluating with EE-space actions on an SO-100 arm.
## Absolute, Relative, and Delta Actions
Regardless of whether you work in joint space or EE space, the action values can be expressed in three different ways. The terminology follows [UMI (Chi et al., 2024)](https://arxiv.org/abs/2402.10329).
### Absolute Actions (LeRobot default)
Each action specifies the target position directly.
**Example** (joint space, chunk of 4):
```
current_state = [45.0, -30.0, 10.0]
action_chunk = [
[46.0, -29.0, 11.0], # go to 46, -29, 11
[47.5, -27.0, 12.0], # go to 47.5, -27, 12
[49.0, -25.0, 13.5], # go to 49, -25, 13.5
[50.0, -24.0, 15.0], # go to 50, -24, 15
]
```
Each value is a target position in the robot's coordinate frame. Simple and direct, but requires a consistent global coordinate frame. This is the default in LeRobot.
### Relative Actions (used by OpenPI / pi0)
Each action in the chunk is an offset from the **current state at the moment of prediction**. All actions in the chunk share the same reference point:
```
current_state = [45.0, -30.0, 10.0]
relative_chunk = [
[1.0, 1.0, 1.0], # +1 from current → target 46, -29, 11
[2.5, 3.0, 2.0], # +2.5 from current → target 47.5, -27, 12
[4.0, 5.0, 3.5], # +4 from current → target 49, -25, 13.5
[5.0, 6.0, 5.0], # +5 from current → target 50, -24, 15
]
```
The conversion is straightforward: `relative = absolute - current_state`. To recover absolute: `absolute = relative + current_state`.
**Why use relative actions?** The model learns to predict offsets centered around zero, which is easier to normalize and leads to more stable training. Because every chunk references the same current state, there is no error accumulation across chunks.
### Delta Actions (sequential differences)
Each action is an offset from the **previous action** (or from the current state for the first step):
```
current_state = [45.0, -30.0, 10.0]
delta_chunk = [
[1.0, 1.0, 1.0], # current → 46, -29, 11
[1.5, 2.0, 1.0], # previous action → 47.5, -27, 12
[1.5, 2.0, 1.5], # previous action → 49, -25, 13.5
[1.0, 1.0, 1.5], # previous action → 50, -24, 15
]
```
Here each step is relative to the one before it. To recover absolute positions you must sum all previous deltas, which means errors accumulate over time. UMI explicitly argues against this representation for this reason.
### Visual Comparison
The figure below (based on a figure from [UMI, Chi et al., 2024](https://arxiv.org/abs/2402.10329)) illustrates the key difference. With **relative trajectory**, every action in the chunk points back to the same origin (current state), so a new inference step cleanly resets the reference. With **delta**, each action depends on the previous one, so errors accumulate. **Absolute** actions require a consistent global coordinate frame.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/action_representations_umi.png"
alt="Relative Trajectory as Action Representation (UMI, Chi et al., 2024)"
width="85%"
/>
## Using Relative Actions in LeRobot
LeRobot provides `RelativeActionsProcessorStep` to convert between absolute and relative actions inside the processor pipeline. This is how pi0, pi0.5, and pi0_fast support relative actions.
> **Note:** All pi models (pi0, pi0.5, pi0*fast) apply relative conversion \_before* normalization (`relative → normalize`), so the normalizer always sees delta (relative) values. This means **relative action stats are required** for all of them when training with `use_relative_actions=true`. In pi0_fast the `RelativeActionsProcessorStep` only modifies the action — the state observation is unchanged — so `NormalizerProcessorStep` still runs before the state tokenizer and the tokenizer continues to receive normalized state as expected.
### How it works
During **training** (preprocessing), actions are converted from absolute to relative before the model sees them:
```
raw absolute action → RelativeActionsProcessorStep → normalize → model
```
During **inference** (postprocessing), model predictions are converted back to absolute before being sent to the robot:
```
model output → unnormalize → AbsoluteActionsProcessorStep → robot
```
The `AbsoluteActionsProcessorStep` reads the cached current state from its paired `RelativeActionsProcessorStep`, so the two must be wired together (handled automatically by the policy factory).
### Enabling relative actions for the pi family (pi0, pi0.5, pi0_fast)
**Step 1**: Precompute relative action statistics for your dataset:
```bash
lerobot-edit-dataset \
--repo_id your_dataset \
--operation.type recompute_stats \
--operation.relative_action true \
--operation.chunk_size 50 \
--operation.relative_exclude_joints "['gripper']"
```
**Step 2**: Train with relative actions enabled:
```bash
lerobot-train \
--dataset.repo_id=your_dataset \
--policy.type=pi0 \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]'
```
The `relative_exclude_joints` parameter specifies joints that should remain in absolute space. For example, gripper commands are typically binary (open/close) and don't benefit from relative encoding.
### Combining relative actions with RTC
[RTC](https://arxiv.org/abs/2506.07339) runs policy inference at high frequency and sends actions to the robot as they are predicted rather than waiting for a full chunk. Relative actions and RTC are fully compatible: because every chunk in relative mode references the **same** current state (captured at the start of inference), each predicted action in the chunk remains a valid offset even if the robot has already moved. No special handling is needed — `RelativeActionsProcessorStep` caches the state once per inference call and `AbsoluteActionsProcessorStep` applies it to every action in the streamed output.
### Combining relative actions with EE space
Relative actions work in both joint space and EE space. For example, if your dataset stores EE actions, relative encoding converts them to offsets from the current EE pose:
```
current_ee_state = [x: 0.25, y: -0.10, z: 0.15, gripper: 0.8]
absolute_ee_chunk = [
[0.26, -0.09, 0.16, 0.8],
[0.28, -0.07, 0.18, 0.8],
]
relative_ee_chunk = [
[0.01, 0.01, 0.01, 0.0], # offset from current EE pose
[0.03, 0.03, 0.03, 0.0], # offset from current EE pose
]
```
## Processing Pipeline Summary
Here is how the different processors compose. Each arrow is a processor step, and they can be chained in a `RobotProcessorPipeline` or `PolicyProcessorPipeline`:
```
┌─────────────────────────────────────────┐
Action Space │ Joint Space ←──IK──→ EE Space │
│ ForwardKinematicsJointsToEE │
│ InverseKinematicsEEToJoints │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
State Derivation │ Action column ────→ State + Action │
│ DeriveStateFromActionStep (pre only) │
│ (UMI-style: state from action chunk) │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
Action Repr. │ Absolute ←────→ Relative │
│ RelativeActionsProcessorStep (pre) │
│ AbsoluteActionsProcessorStep (post) │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
State Repr. │ Absolute ────→ Relative │
│ RelativeStateProcessorStep (pre only) │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
Normalization │ Raw ←────→ Normalized │
│ NormalizerProcessorStep (pre) │
│ UnnormalizerProcessorStep (post) │
└─────────────────────────────────────────┘
```
A typical training preprocessor might chain: `raw absolute joint actions → relative → normalize`. A typical inference postprocessor: `unnormalize → absolute → (optionally IK to joints)`.
With UMI-style relative proprioception (`use_relative_state=True`), the preprocessor also converts observation.state to offsets from the current timestep via `RelativeStateProcessorStep` before normalization. This is a pre-processing-only step (state is an input, not an output).
With `derive_state_from_action=True`, the preprocessor first runs `DeriveStateFromActionStep` to extract a 2-step state from the extended action chunk. This enables full UMI-style training without a separate `observation.state` column. See the [UMI pi0 guide](umi_pi0_relative_ee) for details.
## References
- [Universal Manipulation Interface (UMI)](https://arxiv.org/abs/2402.10329) - Chi et al., 2024. Defines the relative trajectory action representation and compares it with absolute and delta actions.
- [Introduction to Processors](./introduction_processors) - How processor pipelines work in LeRobot.
- [`examples/so100_to_so100_EE/`](https://github.com/huggingface/lerobot/tree/main/examples/so100_to_so100_EE) - Complete example of recording and evaluating with EE-space actions.

View File

@@ -310,4 +310,4 @@ Asynchronous inference represents a significant advancement in real-time robotic
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).

View File

@@ -41,13 +41,15 @@ requires = # your-build-system
## Step 2: Define the Policy Configuration
Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type:
Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
Here is a template to get you started, customize the parameters and methods as needed for your policy's architecture and training requirements.
```python
# configuration_my_custom_policy.py
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("my_custom_policy")
@dataclass
@@ -61,22 +63,56 @@ class MyCustomPolicyConfig(PreTrainedConfig):
hidden_dim: Hidden dimension for the policy network
# Add your policy-specific parameters here
"""
# ...PreTrainedConfig fields...
pass
horizon: int = 50
n_action_steps: int = 50
hidden_dim: int = 256
optimizer_lr: float = 1e-4
optimizer_weight_decay: float = 1e-4
def __post_init__(self):
super().__post_init__()
# Add any validation logic here
if self.n_action_steps > self.horizon:
raise ValueError("n_action_steps cannot exceed horizon")
def validate_features(self) -> None:
"""Validate input/output feature compatibility."""
# Implement validation logic for your policy's requirements
pass
if not self.image_features:
raise ValueError("MyCustomPolicy requires at least one image feature.")
if self.action_feature is None:
raise ValueError("MyCustomPolicy requires 'action' in output_features.")
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
def get_scheduler_preset(self):
return None
@property
def observation_delta_indices(self) -> list[int] | None:
"""Relative timestep offsets the dataset loader provides per observation.
Return `None` for single-frame policies. For temporal policies that consume
multiple past or future frames, return a list of offsets, e.g. `[-20, -10, 0, 10]` for
3 past frames at stride 10 and 1 future frame at stride 10.
"""
return None
@property
def action_delta_indices(self) -> list[int]:
"""Relative timestep offsets for the action chunk the dataset loader returns.
"""
return list(range(self.horizon))
@property
def reward_delta_indices(self) -> None:
return None
```
## Step 3: Implement the Policy Class
Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class:
Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):
```python
# modeling_my_custom_policy.py
@@ -85,38 +121,74 @@ import torch.nn as nn
from typing import Any
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION
from .configuration_my_custom_policy import MyCustomPolicyConfig
class MyCustomPolicy(PreTrainedPolicy):
config_class = MyCustomPolicyConfig
config_class = MyCustomPolicyConfig # must match the string in @register_subclass
name = "my_custom_policy"
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
super().__init__(config, dataset_stats)
config.validate_features() # not called automatically by the base class
self.config = config
self.model = ... # your nn.Module here
def reset(self):
"""Reset episode state."""
...
def get_optim_params(self) -> dict:
"""Return parameters to pass to the optimizer (e.g. with per-group lr/wd)."""
return {"params": self.parameters()}
def predict_action_chunk(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
"""Return the full action chunk (B, chunk_size, action_dim) for the current observation."""
...
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
"""Return a single action for the current timestep (called at inference)."""
...
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Compute the training loss.
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
timesteps padded because the episode ended before `horizon` steps, you
can exclude those from your loss.
"""
actions = batch[ACTION]
action_is_pad = batch.get("action_is_pad")
...
return {"loss": ...}
```
## Step 4: Add Data Processors
Create processor functions:
Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
```python
# processor_my_custom_policy.py
from typing import Any
import torch
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
def make_my_custom_policy_pre_post_processors(
config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Create preprocessing and postprocessing functions for your policy."""
pass # Define your preprocessing and postprocessing logic here
preprocessor = ... # build your PolicyProcessorPipeline for inputs
postprocessor = ... # build your PolicyProcessorPipeline for outputs
return preprocessor, postprocessor
```
**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
## Step 5: Package Initialization
Expose your classes in the package's `__init__.py`:

View File

@@ -424,7 +424,7 @@ robot = SO100Follower(robot_config)
robot.connect()
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[episode_idx])
actions = dataset.hf_dataset.select_columns("action")
actions = dataset.select_columns("action")
log_say(f"Replaying episode {episode_idx}")
for idx in range(dataset.num_frames):

View File

@@ -0,0 +1,340 @@
# Multitask DiT Policy
Multitask Diffusion Transformer (DiT) Policy is an evolution of the original Diffusion Policy architecture, which leverages a large DiT with text and vision conditioning for multitask robot learning. This implementation supports both diffusion and flow matching objectives for action generation, enabling robots to perform diverse manipulation tasks conditioned on language instructions.
## Model Overview
The model uses:
- **CLIP Vision Encoder**: Processes RGB images from multiple camera views
- **CLIP Text Encoder**: Encodes language task instructions (frozen weights with learnable projection)
- **Diffusion Transformer**: Predicts action sequences conditioned on observations and language
- **Two Objectives**: Supports both diffusion (DDPM/DDIM) and flow matching for action generation
This model is exciting because you can achieve extremely high dexterity, competitive with multi-billion parameter
VLAs, with only ~450M parameters and significantly less training.
## Installation Requirements
Multitask DiT Policy has additional dependencies. Install it with:
```bash
pip install lerobot[multi_task_dit]
```
This will install all necessary dependencies including the HuggingFace Transformers library for CLIP models.
## Usage
To use Multitask DiT in your LeRobot configuration, specify the policy type as:
```python
policy.type=multi_task_dit
```
## Training
### Basic Training Command
Here's a complete training command for training Multitask DiT on your dataset:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/multitask_dit_training \
--batch_size=32 \
--steps=5000 \
--save_freq=500 \
--log_freq=100 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```
### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency)
For reliable performance, start with these suggested default hyperparameters:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/mutitask_dit_training \
--batch_size=320 \
--steps=30000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_train_timesteps=100 \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```
**Key Parameters:**
- **Batch Size**: 192-320 - If you have access to a GPU that can support this, you will get the best training dynamics
- **Horizon**: 32 - number of action steps to predict, ~1.0 sec at 30Hz
- **n_action_steps**: 24 - ~0.8 seconds at 30Hz
- **Objective**: `diffusion` - start with diffusion and experiment with flow matching if generation quality is poor
- **Training Steps**: >30k steps recommended for a single task
### Training Configuration Parameters
#### Objective Selection
Choose between diffusion and flow matching:
```bash
# Diffusion objective (default)
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \ # or "DDIM"
--policy.num_train_timesteps=100 \
--policy.num_inference_steps=10 \ # For faster inference
--policy.beta_schedule=squaredcos_cap_v2 \ # Noise schedule type
--policy.prediction_type=epsilon \ # "epsilon" (predict noise) or "sample" (predict clean)
--policy.clip_sample=true \ # Clip samples during denoising
--policy.clip_sample_range=1.0 # Clipping range [-x, x]
# Flow matching objective
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \ # or "uniform" | the beta sampling strategy performance appears much better in practice
--policy.num_integration_steps=100 \
--policy.integration_method=euler \ # or "rk4"
--policy.sigma_min=0.0 # Minimum noise in flow interpolation path
```
#### Transformer Architecture
Adjust model capacity based on dataset size:
```bash
# Small datasets (< 100 examples)
--policy.num_layers=4 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
# Medium datasets (100-5k examples) - default
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
# Large datasets (> 5k examples)
--policy.num_layers=8 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
```
**Positional Encoding Options:**
The model supports two positional encoding methods for action sequences:
```bash
# Rotary Position Embedding (RoPE) - default, recommended
--policy.use_rope=true \
--policy.rope_base=10000.0 # Base frequency for RoPE
# Absolute positional encoding
--policy.use_positional_encoding=true # Disables RoPE when true
```
**Other Transformer Parameters:**
```bash
--policy.dropout=0.1 # Dropout rate for DiT blocks (0.0-1.0)
--policy.timestep_embed_dim=256 # Timestep embedding dimension
```
#### Vision Encoder Configuration
```bash
# Use different CLIP model for more expressivity at the cost of inference time
# experiment with larger or smaller models depending on the complexity of your tasks and size of dataset
--policy.vision_encoder_name=openai/clip-vit-large-patch14
# Use separate vision encoder per camera
# This may be useful when cameras have significantly different characteristics, but
# be wary of increased VRAM footprint.
--policy.use_separate_rgb_encoder_per_camera=true
# Image preprocessing
--policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed ups
--policy.image_crop_shape=[224,224] \
--policy.image_crop_is_random=true # Random during training, center at inference
```
#### Text Encoder Configuration
```bash
# Use different CLIP text encoder model
# same as vision: experiment with larger or smaller models depending on the
# complexity of your tasks and size of dataset
--policy.text_encoder_name=openai/clip-vit-large-patch14
```
#### Learning Rate Configuration
The vision encoder uses a separate learning rate multiplier, where 1/10th is suggested to be the ideal staritng point:
```bash
--policy.optimizer_lr=2e-5 \
--policy.vision_encoder_lr_multiplier=0.1 # Vision encoder LR = 0.1 * optimizer_lr
```
### Training Tuning Guidelines
#### 1. Flow Matching with Beta Sampling
The original diffusion implementation here is based on the work described in [TRI's LBM paper](https://arxiv.org/abs/2507.05331)
Additionally, we have implemented a flow-matching objective, which is described at a high-level in [Boston Dynamics blog post](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/).
Consider testing the flow-matching objective and evaluating performance differences for your task:
```bash
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \
--policy.timestep_sampling_alpha=1.5 \
--policy.timestep_sampling_beta=1.0 \
--policy.timestep_sampling_s=0.999
```
This hasn't been shown to be a silver bullet across every user case, but it occasionally results in smoother and more consistent actions.
#### 2. Number of Transformer Layers
Match model capacity to your dataset size:
- **Small datasets** (< 100 examples): Reduce to 4 layers
- **Large datasets** (> 5k examples): Increase to 8 layers
#### 3. `horizon` Tuning
The model can be sensitive to the horizon you choose. Start with around a 1 second horizon based on your control frequency:
- **30 Hz frequency**: `horizon=30`
- **10 Hz frequency**: `horizon=10`
Then experiment with increasing from there. The horizon determines how far into the future the model predicts actions.
#### 4. `n_action_steps` Sensitivity
The model can also be very sensitive to `n_action_steps`. Start with it being around 0.8 seconds based on your control frequency and tune from there:
- **Lower values**: More reactive but potentially less stable for long-horizon tasks
- **Higher values**: Better for long-horizon execution but open-loop failures are limited in their recovery
### Inference Tuning
For faster inference, use DDIM with fewer sampling steps:
```bash
--policy.noise_scheduler_type=DDIM \
--policy.num_inference_steps=10
```
### Resuming Training
To resume training from a checkpoint:
```bash
lerobot-train \
--config_path=./outputs/mutitask_dit_training/checkpoints/last/pretrained_model/train_config.json \
--resume=true
```
The checkpoint directory should contain `model.safetensors` and `config.json` files (saved automatically during training). When resuming, the configuration is loaded from the checkpoint, so you don't need to specify other parameters.
## Common Failure Modes and Debugging
Training these models can be finicky. Here are common failure modes and debugging approaches:
### Idling / No Motion
The model may "collapse" during inference, resulting in static or no motion. This can occur when:
1. **Insufficient training data**: If you only have 20-50 examples, try to roughly double your dataset size. Once you have above 300 examples, if you're still seeing this, the task may be too complex.
2. **Multiple similar tasks**: When your dataset contains multiple similar tasks (e.g., picking up 2 different objects), the model may rely too heavily on language conditioning which might not be rich enough.
**Debugging tips:**
- Increase dataset size (double until you get to over 300 examples)
- Train for longer, up to 100k steps, even when the loss flatlines
- Check if the model is receiving proper language instructions or increase diversity of instruction
### Executing the Wrong Task
Sometimes the robot will completely ignore your instruction and perform some other task. This generally only happens if you have trained on multiple tasks.
**Potential causes:**
- Language instruction ambiguity
- Insufficient task-specific training data
- Model confusion between similar tasks in the multitask dataset
**Debugging tips:**
- Verify language instruction specificity, especially if descriptions are similar between multiple tasks
- Check task distribution in your training dataset and add weighting to the failing/ignored task
- Consider task-specific fine-tuning
### Training Instability
If training loss is unstable or diverging:
- Try adjusting learning rate between `1e-5` and `3e-4`
- Increase batch size if possible
- Check that your dataset normalization is correct
- Verify image preprocessing is working correctly
## Performance Considerations
### GPU Requirements
- **Inference**: At least an RTX 5070 Ti (or equivalent GPU) is recommended for reasonable speed performance
- **Training**: A GPU with enough VRAM to load batch sizes of >64 is ideal, which will vary depending on the number of image observations, etc
### Batch Size Recommendations
- **Minimum**: 64 (less than this may result in unstable training)
- **Recommended**: 256-320 (best performance, requires larger GPU)
## Example: Training on Custom Dataset
Here's a complete example training on a custom dataset:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/mutitask_dit_training \
--batch_size=320 \
--steps=30000 \
--save_freq=1000 \
--log_freq=100 \
--eval_freq=1000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.vision_encoder_name=openai/clip-vit-base-patch16 \
--policy.image_resize_shape=[320,240] \
--policy.image_crop_shape=[224,224] \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true \
--wandb.project=multitask_dit
```
## References
For more details on the technical implementation and architecture, see:
- [A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation](https://arxiv.org/abs/2507.05331)
- [Large Behavior Models and Atlas Find New Footing](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/)
- [Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy](https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy)

View File

@@ -91,6 +91,46 @@ lerobot-train \
**💡 Tip**: Setting `train_expert_only=true` freezes the VLM and trains only the action expert and projections, allowing finetuning with reduced memory usage.
## Relative Actions
By default, π₀ predicts absolute actions. You can enable **relative actions** so the model predicts offsets relative to the current robot state. This can improve training stability for certain setups.
To use relative actions, first recompute your dataset stats in relative space via the CLI:
```bash
lerobot-edit-dataset \
--repo_id your_dataset \
--operation.type recompute_stats \
--operation.relative_action true \
--operation.chunk_size 50 \
--operation.relative_exclude_joints "['gripper']" \
--push_to_hub true
```
Or equivalently in Python:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.dataset_tools import recompute_stats
dataset = LeRobotDataset("your_dataset")
recompute_stats(dataset, relative_action=True, chunk_size=50, relative_exclude_joints=["gripper"])
dataset.push_to_hub()
```
The `chunk_size` should match your policy's `chunk_size` (default 50 for π₀). `relative_exclude_joints` lists joint names that should remain in absolute space (e.g. gripper commands). Use `--push_to_hub true` to upload the updated stats to the Hub.
Then train with relative actions enabled:
```bash
lerobot-train \
--dataset.repo_id=your_dataset \
--policy.type=pi0 \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]' \
...
```
## License
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).

View File

@@ -97,6 +97,46 @@ python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
Or train pi05 with this normalization mapping: `--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'`
## Relative Actions
By default, π₀.₅ predicts absolute actions. You can enable **relative actions** so the model predicts offsets relative to the current robot state. This can improve training stability for certain setups.
To use relative actions, first recompute your dataset stats in relative space via the CLI:
```bash
lerobot-edit-dataset \
--repo_id your_dataset \
--operation.type recompute_stats \
--operation.relative_action true \
--operation.chunk_size 50 \
--operation.relative_exclude_joints "['gripper']" \
--push_to_hub true
```
Or equivalently in Python:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.dataset_tools import recompute_stats
dataset = LeRobotDataset("your_dataset")
recompute_stats(dataset, relative_action=True, chunk_size=50, relative_exclude_joints=["gripper"])
dataset.push_to_hub()
```
The `chunk_size` should match your policy's `chunk_size` (default 50 for π₀.₅). `relative_exclude_joints` lists joint names that should remain in absolute space (e.g. gripper commands). Use `--push_to_hub true` to upload the updated stats to the Hub.
Then train with relative actions enabled:
```bash
lerobot-train \
--dataset.repo_id=your_dataset \
--policy.type=pi05 \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]' \
...
```
## Performance Results
### Libero Benchmark Results

View File

@@ -0,0 +1,37 @@
# Multitask DiT Policy
## Citation
If you use this work, please cite the following works:
```bibtex
@misc{jones2025multitaskditpolicy,
author = {Bryson Jones},
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
year = {2025},
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
note = {Blog post}
}
```
```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
author = {TRI LBM Team},
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
year = {2025},
eprint = {arXiv:2507.05331},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2507.05331}
}
```
```bibtex
@misc{bostondynamics2025largebehaviormodelsatlas,
author = {Boston Dynamics and TRI Research Team},
title = {Large Behavior Models and Atlas Find New Footing},
year = {2025},
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
note = {Blog post}
}
```

114
docs/source/rename_map.mdx Normal file
View File

@@ -0,0 +1,114 @@
# Rename Map and Empty Cameras
When you train, evaluate, or record with a robot policy, your **dataset** or **environment** provides observations under one set of keys (e.g. `observation.images.front`, `observation.images.eagle`), while your **policy** expects another (e.g. `observation.images.image`, `observation.images.image2`). The **rename map** bridges that gap without changing the policy or data source.
> **Scope:** The rename map only renames **observation** keys (images and state). Action keys are not affected.
## Why observation keys don't always match
Policies have a fixed set of **input feature names** baked into their pretrained config. For example:
- [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero) expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb`.
- [xvla-base](https://huggingface.co/lerobot/xvla-base) expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`.
Your dataset might use different names entirely (e.g. `observation.images.front`, `observation.images.eagle`, `observation.images.glove`), and your eval environment might use yet another set. Rather than editing the policy config or renaming columns in the dataset, you pass a **rename map**: a JSON dictionary that maps source keys to the keys the policy expects. Renaming happens inside the preprocessor pipeline, so the policy always sees its expected keys.
## Using the rename map
Pass the mapping as a JSON string on the command line. The convention is always:
```
--rename_map='{"source_key": "policy_key", ...}'
```
where **source_key** is what the dataset or environment provides, and **policy_key** is what the policy expects.
Only listed keys are renamed; everything else passes through unchanged. Order of entries doesn't matter.
Supported policies: **PI0**, **PI05**, **PI0Fast**, **SmolVLA**, and **XVLA**.
### Training
Suppose you fine-tune [lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base) on a dataset with images under `observation.images.front`, `observation.images.eagle`, and `observation.images.glove`. XVLA expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/xvla_training \
--job_name=xvla_training \
--policy.path="lerobot/xvla-base" \
--policy.repo_id="HF_USER/xvla-your-robot" \
--policy.dtype=bfloat16 \
--policy.action_mode=auto \
--steps=20000 \
--policy.device=cuda \
--policy.freeze_vision_encoder=false \
--policy.freeze_language_encoder=false \
--policy.train_policy_transformer=true \
--policy.train_soft_prompts=true \
--rename_map='{"observation.images.front": "observation.images.image", "observation.images.eagle": "observation.images.image2", "observation.images.glove": "observation.images.image3"}'
```
### Evaluation
A policy that expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb` (e.g. [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)), but the LIBERO environment returns `observation.images.image` and `observation.images.image2`:
```bash
lerobot-eval \
--policy.path=lerobot/pi0fast-libero \
--env.type=libero \
... \
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
```
### Recording
`lerobot-record` also supports rename maps, nested under the dataset config:
```bash
lerobot-record \ # When running inference
--policy.path="<user>/smolVLA_finetuned" \
... \
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
```
## Alternative: edit the policy config directly
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
The tradeoff: modifying the policy config ties it to one data source. A rename map keeps one policy usable across many datasets and environments.
## Empty cameras: fewer views than the policy expects
Some policies are built for a fixed number of image inputs. If your dataset has fewer cameras, you can set **`empty_cameras`** in the policy config instead of modifying the model architecture.
### How it works
Setting `empty_cameras=N` adds N placeholder image features to the policy config, named:
```
observation.images.empty_camera_0
observation.images.empty_camera_1
...
```
At runtime, these keys have no corresponding data in the batch. The policy fills them with masked dummy tensors (padded with `-1` for SigLIP-based vision encoders, with a zero attention mask), so the extra image slots are effectively ignored during training and inference.
### Example
XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset only has two cameras:
1. Set `--policy.empty_cameras=1`.
2. The config adds a third key: `observation.images.empty_camera_0`.
3. Use the rename map for your two real cameras as usual.
4. The third slot is masked out — no fake images needed in your dataset.
## Quick reference
| Goal | What to do |
| ----------------------------------------- | --------------------------------------------------------------------------- |
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. |
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |

View File

@@ -236,10 +236,10 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
### Joint 1
- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn.
- Place the first motor into the base.
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
- Install both motor horns, securing the top horn with a M3x6mm screw.
- Attach the shoulder part.
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
- Add the shoulder motor holder.
@@ -255,9 +255,9 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
### Joint 2
- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn.
- Slide the second motor in from the top.
- Fasten the second motor with 4 M2x6mm screws.
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
- Attach the upper arm with 4 M3x6mm screws on each side.
<div class="video-container">
@@ -271,8 +271,8 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
### Joint 3
- Insert motor 3 and fasten using 4 M2x6mm screws
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn.
- Insert motor 3 and fasten using 4 M2x6mm screws.
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
<div class="video-container">
@@ -286,9 +286,10 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
### Joint 4
- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn.
- Slide over motor holder 4.
- Slide in motor 4.
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
- Fasten motor 4 with 4 M2x6mm screws.
<div class="video-container">
<video controls width="600">
@@ -321,7 +322,7 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
- Attach the motor horns and again use a M3x6mm horn screw.
- Install both motor horns on the gripper motor. Secure the top horn with a M3x6mm screw; no screws are required for the bottom horn.
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
<div class="video-container">

View File

@@ -0,0 +1,227 @@
# UMI Data with pi0 Relative EE Actions
This guide explains how to train a pi0 policy with UMI-style relative end-effector (EE) actions and deploy it on a real OpenArm robot.
**What we will do:**
1. Prepare the dataset (EE pose + gripper in the action column).
2. Recompute statistics for relative actions.
3. Train pi0 with `derive_state_from_action=true`.
4. Evaluate the trained policy on a real robot.
## Background
[UMI (Universal Manipulation Interface)](https://umi-gripper.github.io) collects manipulation data with hand-held grippers, recovering 6-DoF EE poses via SLAM. The key insight from UMI (Chi et al., 2024) is that the action space must include **both EE trajectory and gripper width**, and actions should be expressed as **relative trajectories** (offsets from the current pose).
### Dataset layout
The dataset should have this structure:
| Feature | Shape | Content |
| ------------------------- | --------- | -------------------------------------------------------- |
| `observation.images.cam0` | `[3,H,W]` | Wrist camera image |
| `action` | `[8]` | `[x, y, z, ax, ay, az, proximal, distal]` (EE + gripper) |
No separate `observation.pose` or `observation.joints` columns are needed — the model derives its proprioception state directly from the action column (`derive_state_from_action=true`).
### Why relative actions?
With relative actions, each action in a chunk is an **offset from the current state** rather than an absolute target:
```
relative_action[i] = absolute_action[t + i] state[t]
```
UMI ablations show this is critical: absolute actions achieve only 25% success vs 100% for relative trajectory on the cup arrangement task. Compared to delta actions (each step relative to the previous), relative trajectory avoids error accumulation. See the [Action Representations](action_representations) guide for details.
### `derive_state_from_action`
When `derive_state_from_action=true`, pi0 derives `observation.state` from the action column during training — no separate state column needed. Under the hood:
- `action_delta_indices` extends to `[-1, 0, 1, ..., chunk_size-1]` (one extra leading timestep).
- `DeriveStateFromActionStep` extracts `[action[t-1], action[t]]` as a 2-step state and strips the extra timestep from the action chunk.
- `RelativeActionsProcessorStep` converts actions to offsets from `state[t]`.
- `RelativeStateProcessorStep` converts the 2-step state to relative proprioception (velocity + zeros) and flattens.
This implies `use_relative_state=true` and `state_obs_steps=2`.
During **inference**, `DeriveStateFromActionStep` is a no-op — state comes from the robot via forward kinematics. `RelativeStateProcessorStep` buffers the previous state and applies the same conversion automatically.
## Step 1: Recompute Stats
After preparing the dataset with EE pose in the action column, recompute statistics with `derive_state_from_action=true`. This computes relative action and state stats so the normalizer sees offset distributions:
```bash
lerobot-edit-dataset \
--repo-id=glannuzel/grabette-dataset \
--operation=recompute_stats \
--operation.relative_action=true \
--operation.relative_exclude_joints='["proximal", "distal"]' \
--operation.derive_state_from_action=true \
--operation.chunk_size=30 \
--push_to_hub=true
```
| Flag | Purpose |
| ------------------------------- | ------------------------------------------------------------------------------- |
| `relative_action=true` | Compute stats on `action state` (relative actions) |
| `relative_exclude_joints` | Keep gripper dims absolute (they don't benefit from relative encoding) |
| `derive_state_from_action=true` | Derive state from action column (implies `relative_state`, `state_obs_steps=2`) |
| `chunk_size=30` | Must match training chunk size |
## Step 2: Train
```bash
#!/bin/bash
set -euo pipefail
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:${LD_LIBRARY_PATH:-}
DATASET="glannuzel/grabette-dataset"
NUM_PROCESSES=8
echo "=== Training pi0 on $DATASET (UMI relative EE, ${NUM_PROCESSES} GPUs) ==="
accelerate launch --multi_gpu --num_processes=$NUM_PROCESSES \
-m lerobot.scripts.lerobot_train \
--dataset.repo_id="$DATASET" \
--dataset.video_backend=pyav \
--policy.type=pi0 \
--policy.pretrained_path=lerobot/pi0_base \
--policy.repo_id=pepijn/grabette-umi-pi0 \
--policy.chunk_size=30 \
--policy.n_action_steps=30 \
--policy.derive_state_from_action=true \
--use_relative_actions=true \
--policy.relative_exclude_joints='["proximal", "distal"]' \
--batch_size=32 \
--steps=5000 \
--policy.scheduler_decay_steps=5000 \
--policy.dtype=bfloat16 \
--policy.compile_model=false \
--policy.gradient_checkpointing=true \
--policy.device=cuda \
--output_dir=/fsx/pepijn/outputs/grabette-umi \
--job_name=grabette-umi-v2 \
--wandb.enable=true \
--wandb.disable_artifact=true \
--wandb.project=grabette-umi \
--log_freq=100 \
--save_freq=5000
```
Key flags:
| Flag | Purpose |
| ------------------------------- | ---------------------------------------------------------------------- |
| `derive_state_from_action=true` | Derive proprioception from action column (full UMI mode) |
| `use_relative_actions=true` | Actions are offsets from current state |
| `relative_exclude_joints` | `["proximal", "distal"]` — gripper stays absolute, EE pose is relative |
| `chunk_size=30` | Action horizon: 30 steps (~0.65s at 46 FPS) |
| `n_action_steps=30` | Execute full chunk before replanning |
Note: `derive_state_from_action=true` automatically implies `use_relative_state=true` and `state_obs_steps=2`. No `rename_map` is needed since there are no separate observation columns to rename.
## Step 3: Evaluate
The evaluation script in `examples/umi_pi0_relative_ee/evaluate.py` runs inference on a real OpenArm robot:
```bash
python examples/umi_pi0_relative_ee/evaluate.py
```
Edit `HF_MODEL_ID`, camera index, and robot configuration at the top of the file.
### How inference works
At inference, the training dataset has no `observation.state` — it was derived from actions. The evaluate script provides `observation.state` from the robot via forward kinematics:
1. **Robot → FK** — Arm joint positions → EE pose `[x,y,z,ax,ay,az]`, gripper → `[proximal, distal]`. Combined into `observation.state` (8D).
2. **Preprocessor** (loaded from checkpoint) — `DeriveStateFromActionStep` is a no-op. `RelativeStateProcessorStep` buffers previous state, stacks `[prev, current]`, subtracts current → velocity info. `RelativeActionsProcessorStep` caches state. `NormalizerProcessorStep` normalizes.
3. **pi0 inference** — Predicts normalized relative action chunk (30 steps).
4. **Postprocessor** — `UnnormalizerProcessorStep` unnormalizes, `AbsoluteActionsProcessorStep` adds cached state → absolute EE targets.
5. **IK → Robot** — Absolute `[x,y,z,ax,ay,az]` → arm joint targets with full 6-DOF IK (orientation weight = 1.0). `[proximal, distal]` → direct gripper position commands.
### Latency compensation
Set `LATENCY_SKIP_STEPS` to skip the first few predicted action steps, compensating for system latency:
```python
LATENCY_SKIP_STEPS = 7 # ceil(total_latency_ms / (1000 / FPS))
```
At 46 FPS (~22ms/step) with ~150ms total latency: `ceil(150/22) ≈ 7`. Start with 0 for a safe first test.
## Replay Viewer
Visualize any dataset episode in a browser-based 3D viewer before running on hardware. The viewer shows the EE trajectory overlaid on the OpenArm URDF model.
### Quick start
```bash
python examples/umi_pi0_relative_ee/replay.py
```
### Options
| Flag | Default | Description |
| ----------- | ---------------------------- | ------------------------------------ |
| `--repo-id` | `glannuzel/grabette-dataset` | HuggingFace dataset repo to load |
| `--episode` | `0` | Episode index to replay |
| `--port` | `8765` | HTTP server port |
| `--force` | off | Re-extract trajectory even if cached |
### Viewer controls
The panel in the top-left corner shows live EE coordinates and gripper state. Transport controls:
- **Play / Pause** — toggle automatic playback.
- **Step buttons** (◀ ▶) — advance or rewind one frame.
- **Reset** (⟳) — jump to frame 0.
- **Scrubber** — drag to seek.
- **Speed selector** — 0.25× to 4× playback speed.
### Color legend
| Color | Meaning |
| ------------------ | --------------------------------------------- |
| Red sphere | Current EE position |
| Yellow trail | Past trajectory |
| Dark trail | Future trajectory |
| Orange ring + axes | URDF `ee_target` frame (zero-joint reference) |
## How the Pieces Fit Together
```
Training (derive_state_from_action=true):
DataLoader loads action: [B, 31, 8] (chunk_size=30 + 1 leading)
→ DeriveStateFromActionStep
state = action[:, :2, :] → [B, 2, 8]
action = action[:, 1:, :] → [B, 30, 8]
→ RelativeActionsProcessorStep (action -= state[:, -1, :])
→ RelativeStateProcessorStep (state offsets from current, flatten → [B, 16])
→ NormalizerProcessorStep → pi0 model
Inference:
arm joints → FK → observation.state [8D: x,y,z,ax,ay,az,prox,dist]
DeriveStateFromActionStep (no-op)
RelativeActionsProcessorStep (caches state)
RelativeStateProcessorStep (buffers prev, stacks, subtracts, flattens)
NormalizerProcessorStep → pi0 model → relative action chunk [30, 8]
UnnormalizerProcessorStep
AbsoluteActionsProcessorStep (+ cached state → absolute EE)
IK → joint targets → robot
```
## References
- [UMI: Universal Manipulation Interface](https://umi-gripper.github.io) — Chi et al., 2024. Defines relative trajectory actions.
- [Action Representations](action_representations) — LeRobot guide comparing absolute, relative, and delta actions.
- [pi0 documentation](pi0) — Full pi0 configuration including `use_relative_actions`.
- [`examples/so100_to_so100_EE/`](https://github.com/huggingface/lerobot/tree/main/examples/so100_to_so100_EE) — EE-space evaluation example this builds on.

View File

@@ -78,7 +78,7 @@ def replay(cfg: ReplayConfig):
robot = make_robot_from_config(cfg.robot)
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
actions = dataset.hf_dataset.select_columns(ACTION)
actions = dataset.select_columns(ACTION)
robot.connect()
try:

View File

@@ -88,9 +88,8 @@ def main():
# The previous metadata class is contained in the 'meta' attribute of the dataset:
print(dataset.meta)
# LeRobotDataset actually wraps an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets for more information).
print(dataset.hf_dataset)
# You can inspect the dataset using its repr:
print(dataset)
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset.

View File

@@ -35,9 +35,7 @@ def main():
# Fetch the dataset to replay
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)
actions = dataset.select_columns(ACTION)
# Connect to the robot
robot.connect()
@@ -48,7 +46,7 @@ def main():
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
for idx in range(dataset.num_frames):
t0 = time.perf_counter()
# Get recorded action from dataset

View File

@@ -67,9 +67,7 @@ def main():
# Fetch the dataset to replay
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)
actions = dataset.select_columns(ACTION)
# Connect to the robot
robot.connect()
@@ -80,7 +78,7 @@ def main():
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
for idx in range(dataset.num_frames):
t0 = time.perf_counter()
# Get recorded action from dataset

View File

@@ -63,6 +63,26 @@ Usage:
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot-data-collection/folding_final \
--robot.type=bi_openarm_follower \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \
--robot.left_arm_config.can_interface=socketcan \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.right_arm_config.can_interface=socketcan \
--task="Fold the T-shirt properly" \
--fps=30 \
--duration=2000 \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \
--device=cuda
"""
import logging
@@ -87,21 +107,29 @@ from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor import (
NormalizerProcessorStep,
RelativeActionsProcessorStep,
TransitionKey,
create_transition,
)
from lerobot.processor.factory import (
make_default_robot_action_processor,
make_default_robot_observation_processor,
)
from lerobot.processor.relative_action_processor import to_relative_actions
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_so_follower,
koch_follower,
so_follower,
unitree_g1,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
@@ -212,6 +240,35 @@ def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def _reanchor_relative_rtc_prefix(
prev_actions_absolute: Tensor,
current_state: Tensor,
relative_step: RelativeActionsProcessorStep,
normalizer_step: NormalizerProcessorStep | None,
policy_device: torch.device | str,
) -> Tensor:
"""Convert absolute leftovers into model-space for relative-action RTC policies.
When a policy uses relative actions, the RTC prefix (leftover actions from
the previous chunk) is stored in absolute space. Before feeding it back to
the policy we need to re-express it relative to the *current* robot state
and then re-normalize.
"""
state = current_state.detach().cpu()
if state.dim() == 1:
state = state.unsqueeze(0)
action_cpu = prev_actions_absolute.detach().cpu()
mask = relative_step._build_mask(action_cpu.shape[-1])
relative_actions = to_relative_actions(action_cpu, state, mask)
transition = create_transition(action=relative_actions)
if normalizer_step is not None:
transition = normalizer_step(transition)
return transition[TransitionKey.ACTION].to(policy_device)
def get_actions(
policy,
robot: RobotWrapper,
@@ -237,7 +294,15 @@ def get_actions(
fps = cfg.fps
time_per_chunk = 1.0 / fps
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
# Only keep .pos joints + camera streams if the policy was trained on positions,
# not the full pos/vel/torque state the robot exposes.
observation_features_hw = {
key: value
for key, value in robot.observation_features().items()
if key.endswith(".pos") or isinstance(value, tuple)
}
dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
policy_device = policy.config.device
# Load preprocessor and postprocessor from pretrained files
@@ -255,6 +320,25 @@ def get_actions(
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
relative_step = next(
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
None,
)
normalizer_step = next(
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
None,
)
if relative_step is not None:
if relative_step.action_names is None:
cfg_names = getattr(cfg.policy, "action_feature_names", None)
if cfg_names:
relative_step.action_names = list(cfg_names)
else:
relative_step.action_names = [
k for k in robot.robot.action_features if k.endswith(".pos")
]
logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
@@ -297,6 +381,28 @@ def get_actions(
preproceseded_obs = preprocessor(obs_with_policy_features)
# Re-anchor leftover actions for relative-action policies.
# We need the *postprocessed* (absolute) leftover, not the original
# (normalized/relative) one that get_left_over() returns.
if (
prev_actions is not None
and relative_step is not None
and OBS_STATE in obs_with_policy_features
):
with action_queue.lock:
if action_queue.queue is not None:
prev_actions_abs = action_queue.queue[action_queue.last_index :].clone()
else:
prev_actions_abs = None
if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
prev_actions = _reanchor_relative_rtc_prefix(
prev_actions_absolute=prev_actions_abs,
current_state=obs_with_policy_features[OBS_STATE],
relative_step=relative_step,
normalizer_step=normalizer_step,
policy_device=policy_device,
)
# Generate actions WITH RTC
actions = policy.predict_action_chunk(
preproceseded_obs,
@@ -352,6 +458,8 @@ def actor_control(
try:
logger.info("[ACTOR] Starting actor thread")
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]
action_count = 0
action_interval = 1.0 / cfg.fps
@@ -363,7 +471,7 @@ def actor_control(
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)

View File

@@ -68,9 +68,7 @@ def main():
# Fetch the dataset to replay
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)
actions = dataset.select_columns(ACTION)
# Connect to the robot
robot.connect()
@@ -81,7 +79,7 @@ def main():
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
for idx in range(dataset.num_frames):
t0 = time.perf_counter()
# Get recorded action from dataset

View File

@@ -0,0 +1,297 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inference script for a pi0 model trained with UMI-style relative EE actions
on an OpenArm robot (single right arm, one wrist camera).
Training dataset layout:
observation.images.cam0 [3, 720, 960]
action [x, y, z, ax, ay, az, proximal, distal] (shape 8)
The model uses ``derive_state_from_action=true``, so observation.state is
derived from the action column during training. At inference the state must
be provided by the robot — this script uses FK to compute the current EE
pose and gripper position, which it exposes as ``observation.state``.
Pipeline:
1. Read arm joints from robot → FK → observation.state [x,y,z,ax,ay,az,prox,dist]
2. Read camera image → observation.images.cam0
3. pi0 preprocessor (loaded from checkpoint):
- DeriveStateFromActionStep: no-op at inference (state from robot)
- RelativeActionsProcessorStep: caches current state
- RelativeStateProcessorStep: buffers prev state, stacks [prev,cur],
subtracts current → velocity info, flattens
- NormalizerProcessorStep: normalizes
4. pi0 predicts relative action chunk (30 steps)
5. pi0 postprocessor: unnormalize, add cached state → absolute EE
6. IK: absolute EE [x,y,z,ax,ay,az] → arm joint targets
7. Gripper [proximal, distal] → gripper motor targets
8. Send to robot
Usage:
python evaluate.py
"""
from __future__ import annotations
import numpy as np
from scipy.spatial.transform import Rotation
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.processor import RelativeStateProcessorStep
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
# ---------------------------------------------------------------------------
# Configuration — adapt these to your setup
# ---------------------------------------------------------------------------
FPS = 46
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "red cube"
HF_MODEL_ID = "pepijn223/grabette-umi-pi0"
# Latency compensation: skip this many predicted action steps to account for
# camera + inference + execution latency. Formula: ceil(total_ms / (1000/FPS)).
# At 46 FPS (~22ms/step) with ~150ms total latency: ceil(150/22) ≈ 7.
# Start with 0 for a safe first test, then increase to match measured latency.
LATENCY_SKIP_STEPS = 0
URDF_PATH = "src/lerobot/robots/openarm_follower/urdf/openarm_bimanual_pybullet.urdf"
URDF_EE_FRAME = "openarm_right_ee_target"
IK_POSITION_WEIGHT = 1.0
IK_ORIENTATION_WEIGHT = 1.0
# ---------------------------------------------------------------------------
# Dataset features for inference
#
# The training dataset has only observation.images.cam0 and action.
# observation.state is derived from action during training
# (derive_state_from_action=true) but must be supplied by the robot at
# inference. We define it here so build_dataset_frame can map FK output
# to the right feature.
# ---------------------------------------------------------------------------
DATASET_FEATURES: dict = {
"observation.state": {
"dtype": "float32",
"shape": [8],
"names": ["x", "y", "z", "ax", "ay", "az", "proximal", "distal"],
},
"observation.images.cam0": {
"dtype": "video",
"shape": [3, 720, 960],
"names": ["channels", "height", "width"],
"info": {
"video.height": 720,
"video.width": 960,
"video.codec": "h264",
"video.pix_fmt": "yuv420p",
"video.is_depth_map": False,
"video.fps": FPS,
"video.channels": 3,
"has_audio": False,
},
},
"action": {
"dtype": "float32",
"shape": [8],
"names": ["x", "y", "z", "ax", "ay", "az", "proximal", "distal"],
},
"timestamp": {"dtype": "float32", "shape": [1], "names": None},
"frame_index": {"dtype": "int64", "shape": [1], "names": None},
"episode_index": {"dtype": "int64", "shape": [1], "names": None},
"index": {"dtype": "int64", "shape": [1], "names": None},
"task_index": {"dtype": "int64", "shape": [1], "names": None},
}
# ---------------------------------------------------------------------------
# FK / IK callables
# ---------------------------------------------------------------------------
class JointsToEE:
"""FK: raw robot observation → flat dict matching observation.state names.
Arm joint positions → EE pose [x,y,z,ax,ay,az] via forward kinematics.
Gripper motor positions → [proximal, distal].
Camera images pass through unchanged.
"""
def __init__(self, kinematics: RobotKinematics, arm_motor_names: list[str]):
self.kin = kinematics
self.arm = arm_motor_names
def __call__(self, obs: RobotObservation) -> RobotObservation:
q = np.array([float(obs[f"{m}.pos"]) for m in self.arm])
t = self.kin.forward_kinematics(q)
rot = Rotation.from_matrix(t[:3, :3]).as_rotvec()
out: dict = {
"x": float(t[0, 3]),
"y": float(t[1, 3]),
"z": float(t[2, 3]),
"ax": float(rot[0]),
"ay": float(rot[1]),
"az": float(rot[2]),
"proximal": float(obs["proximal.pos"]),
"distal": float(obs["distal.pos"]),
}
for k, v in obs.items():
if not k.endswith((".pos", ".vel", ".torque")):
out[k] = v
return out
class EEToJoints:
"""IK: policy action dict → motor position dict for the robot.
Reads [x,y,z,ax,ay,az] from the action, runs IK for arm joint targets.
Passes [proximal, distal] as direct gripper position commands.
"""
def __init__(
self,
kinematics: RobotKinematics,
arm_motor_names: list[str],
position_weight: float = 1.0,
orientation_weight: float = 1.0,
):
self.kin = kinematics
self.arm = arm_motor_names
self.pw = position_weight
self.ow = orientation_weight
self.q_curr: np.ndarray | None = None
def __call__(self, args: tuple[RobotAction, RobotObservation]) -> RobotAction:
action, obs = args
q_raw = np.array([float(obs[f"{m}.pos"]) for m in self.arm])
if self.q_curr is None:
self.q_curr = q_raw
t_des = np.eye(4)
t_des[:3, :3] = Rotation.from_rotvec([action["ax"], action["ay"], action["az"]]).as_matrix()
t_des[:3, 3] = [action["x"], action["y"], action["z"]]
q_target = self.kin.inverse_kinematics(
self.q_curr, t_des, position_weight=self.pw, orientation_weight=self.ow
)
self.q_curr = q_target
out: dict = {f"{m}.pos": float(q_target[i]) for i, m in enumerate(self.arm)}
out["proximal.pos"] = float(action["proximal"])
out["distal.pos"] = float(action["distal"])
return out
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
camera_config = {
"cam0": OpenCVCameraConfig(index_or_path=0, width=960, height=720, fps=FPS),
}
robot_config = OpenArmFollowerConfig(
port="can0",
id="right_openarm",
side="right",
cameras=camera_config,
max_relative_target=8.0,
gripper_port="/dev/ttyUSB0",
)
robot = OpenArmFollower(robot_config)
policy = PI0Policy.from_pretrained(HF_MODEL_ID)
policy.config.latency_skip_steps = LATENCY_SKIP_STEPS
arm_motor_names = list(robot.bus.motors.keys())
kinematics = RobotKinematics(
urdf_path=URDF_PATH,
target_frame_name=URDF_EE_FRAME,
joint_names=arm_motor_names,
)
fk = JointsToEE(kinematics, arm_motor_names)
ik = EEToJoints(kinematics, arm_motor_names, IK_POSITION_WEIGHT, IK_ORIENTATION_WEIGHT)
dataset = LeRobotDataset.create(
repo_id="tmp/openarm_eval_scratch",
fps=FPS,
features=DATASET_FEATURES,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
)
relative_state_steps = [s for s in preprocessor.steps if isinstance(s, RelativeStateProcessorStep)]
robot.connect()
listener, events = init_keyboard_listener()
init_rerun(session_name="openarm_umi_pi0_relative_ee_evaluate")
try:
if not robot.is_connected:
raise ValueError("Robot is not connected!")
log_say("Starting policy execution")
for step in relative_state_steps:
step.reset()
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
robot_action_processor=ik,
robot_observation_processor=fk,
)
finally:
robot.disconnect()
listener.stop()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,113 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Replay a dataset episode in EE frame using a browser-based URDF viewer.
Extracts ``observation.pose`` from the dataset, saves a trajectory JSON file,
then launches a local HTTP server and opens the replay viewer. The trajectory
is re-centered so frame 0 starts at the OpenArm ``openarm_right_ee_target``
EE tip (zero-joint pose).
Usage:
python replay.py
python replay.py --episode 3 --repo-id myuser/mydata
"""
from __future__ import annotations
import argparse
import http.server
import json
import os
import threading
import webbrowser
from pathlib import Path
VIEWER_DIR = Path(__file__).resolve().parents[2] / "src/lerobot/robots/openarm_follower/urdf"
TRAJECTORY_FILENAME = "trajectory_ep0.json"
def extract_trajectory(repo_id: str, episode: int, output_path: Path) -> dict:
from lerobot.datasets.lerobot_dataset import LeRobotDataset
dataset = LeRobotDataset(repo_id, episodes=[episode])
poses = dataset.select_columns("observation.pose")
actions = dataset.select_columns("action")
frames = []
for i in range(dataset.num_frames):
p = poses[i]["observation.pose"]
a = actions[i]["action"]
frames.append(
{
"x": float(p[0]),
"y": float(p[1]),
"z": float(p[2]),
"ax": float(p[3]),
"ay": float(p[4]),
"az": float(p[5]),
"proximal": float(a[0]),
"distal": float(a[1]),
}
)
payload = {"fps": dataset.fps, "num_frames": dataset.num_frames, "frames": frames}
with open(output_path, "w") as f:
json.dump(payload, f)
print(f"Extracted {dataset.num_frames} frames at {dataset.fps} FPS → {output_path}")
return payload
# ---------------------------------------------------------------------------
# Viewer mode
# ---------------------------------------------------------------------------
def serve_and_open(directory: Path, port: int = 8765):
os.chdir(directory)
handler = http.server.SimpleHTTPRequestHandler
httpd = http.server.HTTPServer(("", port), handler)
url = f"http://localhost:{port}/replay_viewer.html"
print(f"Serving at {url}")
threading.Thread(target=lambda: webbrowser.open(url), daemon=True).start()
try:
httpd.serve_forever()
except KeyboardInterrupt:
print("\nServer stopped.")
httpd.server_close()
def run_viewer(args):
trajectory_path = VIEWER_DIR / TRAJECTORY_FILENAME
if not trajectory_path.exists() or args.force:
extract_trajectory(args.repo_id, args.episode, trajectory_path)
else:
print(f"Using cached trajectory at {trajectory_path} (pass --force to re-extract)")
serve_and_open(VIEWER_DIR, args.port)
def main():
parser = argparse.ArgumentParser(description="Replay a dataset episode in EE frame (URDF viewer)")
parser.add_argument("--repo-id", default="glannuzel/grabette-dataset")
parser.add_argument("--episode", type=int, default=0)
parser.add_argument("--port", type=int, default=8765)
parser.add_argument("--force", action="store_true", help="Re-extract trajectory even if cached")
args = parser.parse_args()
run_viewer(args)
if __name__ == "__main__":
main()

View File

@@ -99,7 +99,7 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.9.17"]
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
@@ -145,6 +145,7 @@ wallx = [
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
multi_task_dit = ["lerobot[transformers-dep]"]
groot = [
"lerobot[transformers-dep]",
"lerobot[peft]",
@@ -305,7 +306,8 @@ default.extend-ignore-identifiers-re = [
"thw",
"inpt",
"ROBOTIS",
"OT_VALUE"
"OT_VALUE",
"metalness",
]
# TODO: Uncomment when ready to use

View File

@@ -27,7 +27,8 @@ class DatasetConfig:
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
# Root directory for a concrete local dataset tree (e.g. 'dataset/path'). If None, local datasets are
# looked up under $HF_LEROBOT_HOME/repo_id and Hub downloads use a revision-safe cache under $HF_LEROBOT_HOME/hub.
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)

View File

@@ -115,6 +115,17 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError
@property
def state_delta_indices(self) -> list | None: # type: ignore[type-arg]
"""Delta indices specifically for observation.state.
When not None, overrides ``observation_delta_indices`` for the
``observation.state`` key only. Useful for loading state history
(e.g. ``[-1, 0]`` for UMI-style relative proprioception) without
also loading multiple image timesteps.
"""
return None
@abc.abstractmethod
def get_optimizer_preset(self) -> OptimizerConfig:
raise NotImplementedError

View File

@@ -0,0 +1,33 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig
__all__ = [
"EpisodeAwareSampler",
"ImageTransforms",
"ImageTransformsConfig",
"LeRobotDataset",
"LeRobotDatasetMetadata",
"MultiLeRobotDataset",
"StreamingLeRobotDataset",
]

View File

@@ -13,9 +13,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import numpy as np
from lerobot.datasets.io_utils import load_image_as_numpy
from lerobot.utils.constants import ACTION, OBS_STATE
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
@@ -624,3 +629,232 @@ def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
return aggregated_stats
def _get_valid_chunk_starts(episode_indices: np.ndarray, chunk_size: int) -> np.ndarray:
"""Return all start indices where a chunk of ``chunk_size`` stays within one episode."""
total = len(episode_indices)
if total < chunk_size:
return np.array([], dtype=np.int64)
max_start = total - chunk_size
starts = np.arange(max_start + 1)
valid = episode_indices[starts] == episode_indices[starts + chunk_size - 1]
return starts[valid]
def _compute_relative_chunk_batch(
start_indices: np.ndarray,
all_actions: np.ndarray,
all_states: np.ndarray,
chunk_size: int,
relative_mask: np.ndarray,
) -> np.ndarray:
"""Vectorised relative-action computation for a batch of start indices.
Returns an ``(N * chunk_size, action_dim)`` float32 array.
"""
if len(start_indices) == 0:
return np.empty((0, all_actions.shape[1]), dtype=np.float32)
offsets = np.arange(chunk_size)
frame_idx = start_indices[:, None] + offsets[None, :]
chunks = all_actions[frame_idx].copy()
states = all_states[start_indices]
mask_dim = len(relative_mask)
chunks[:, :, :mask_dim] -= states[:, None, :mask_dim] * relative_mask[None, None, :]
return chunks.reshape(-1, all_actions.shape[1])
def compute_relative_action_stats(
hf_dataset,
features: dict,
chunk_size: int,
exclude_joints: list[str] | None = None,
num_workers: int = 0,
) -> dict[str, np.ndarray]:
"""Compute normalization statistics for relative actions over the full dataset.
Iterates *all* valid action chunks (within single episodes), converts them to
relative actions (action current_state), and computes per-dimension
statistics suitable for normalization.
Args:
hf_dataset: The underlying HuggingFace dataset with "action",
"observation.state", and "episode_index" columns.
features: Dataset feature metadata (must contain "action" with "shape"
and optionally "names").
chunk_size: Number of consecutive frames per action chunk.
exclude_joints: Joint names whose dimensions should remain absolute
(not converted to relative actions).
num_workers: Number of parallel threads for computation. Values ≤1
mean single-threaded. Numpy releases the GIL so threads give
real parallelism here.
Returns:
Statistics dict with keys "mean", "std", "min", "max", "q01", …, "q99".
Raises:
ValueError: If the dataset has fewer frames than ``chunk_size``.
RuntimeError: If no valid (single-episode) chunks are found.
"""
from lerobot.processor.relative_action_processor import RelativeActionsProcessorStep
if exclude_joints is None:
exclude_joints = []
action_dim = features[ACTION]["shape"][0]
action_names = features.get(ACTION, {}).get("names")
mask_step = RelativeActionsProcessorStep(
enabled=True,
exclude_joints=exclude_joints,
action_names=action_names,
)
relative_mask = np.array(mask_step._build_mask(action_dim), dtype=np.float32)
logging.info("Loading action/state data for relative action stats...")
all_actions = np.array(hf_dataset[ACTION], dtype=np.float32)
all_states = np.array(hf_dataset[OBS_STATE], dtype=np.float32)
episode_indices = np.array(hf_dataset["episode_index"])
valid_starts = _get_valid_chunk_starts(episode_indices, chunk_size)
if len(valid_starts) == 0:
raise RuntimeError(
f"No valid chunks found (total_frames={len(episode_indices)}, chunk_size={chunk_size})"
)
effective_workers = max(num_workers, 1)
logging.info(
f"Computing relative action stats from {len(valid_starts)} chunks "
f"(chunk_size={chunk_size}, workers={effective_workers})"
)
batch_size = 50_000
batches = [valid_starts[i : i + batch_size] for i in range(0, len(valid_starts), batch_size)]
running_stats = RunningQuantileStats()
if num_workers > 1:
from concurrent.futures import ThreadPoolExecutor, as_completed
with ThreadPoolExecutor(max_workers=num_workers) as pool:
futures = [
pool.submit(
_compute_relative_chunk_batch,
batch,
all_actions,
all_states,
chunk_size,
relative_mask,
)
for batch in batches
]
for future in as_completed(futures):
running_stats.update(future.result())
else:
for batch in batches:
running_stats.update(
_compute_relative_chunk_batch(batch, all_actions, all_states, chunk_size, relative_mask)
)
stats = running_stats.get_statistics()
excluded_dims = int(len(relative_mask) - relative_mask.sum())
total_frames = len(valid_starts) * chunk_size
logging.info(
f"Relative action stats ({len(valid_starts)} chunks, {total_frames} frames): "
f"relative_dims={int(relative_mask.sum())}/{len(relative_mask)} (excluded={excluded_dims}), "
f"mean={np.abs(stats['mean']).mean():.4f}, std={stats['std'].mean():.4f}, "
f"q01={stats['q01'].mean():.4f}, q99={stats['q99'].mean():.4f}"
)
return stats
def compute_relative_state_stats(
hf_dataset,
features: dict,
state_obs_steps: int = 2,
exclude_joints: list[str] | None = None,
source_key: str = OBS_STATE,
) -> dict[str, np.ndarray]:
"""Compute normalization statistics for observation.state after relative conversion.
For UMI-style relative proprioception with ``state_obs_steps`` timesteps,
each state observation becomes a stack of offsets from the current timestep:
``state[t-k] - state[t]`` for k in ``range(state_obs_steps-1, -1, -1)``.
The stats are computed over the flattened ``[state_obs_steps * state_dim]``
vector that the model actually sees after ``prepare_state`` flattening.
Args:
hf_dataset: The HuggingFace dataset with the source column and
"episode_index" columns.
features: Dataset feature metadata.
state_obs_steps: Number of observation timesteps (must be >= 2).
exclude_joints: State dimension names to keep absolute.
source_key: Column to read data from. Defaults to "observation.state".
When ``derive_state_from_action=True``, pass ``ACTION`` to read
from the action column instead.
Returns:
Statistics dict with keys "mean", "std", "min", "max", "q01", …, "q99".
"""
from lerobot.processor.relative_action_processor import RelativeStateProcessorStep
if exclude_joints is None:
exclude_joints = []
state_dim = features[source_key]["shape"][0]
state_names = features.get(source_key, {}).get("names")
mask_step = RelativeStateProcessorStep(
enabled=True,
exclude_joints=exclude_joints,
state_names=state_names,
)
relative_mask = np.array(mask_step._build_mask(state_dim), dtype=np.float32)
logging.info(f"Loading data from '{source_key}' for relative state stats...")
all_states = np.array(hf_dataset[source_key], dtype=np.float32)
episode_indices = np.array(hf_dataset["episode_index"])
# Build all valid windows of length state_obs_steps within each episode
n = len(all_states)
if n < state_obs_steps:
raise ValueError(f"Dataset has {n} frames but state_obs_steps={state_obs_steps}")
max_start = n - state_obs_steps
starts = np.arange(max_start + 1)
valid = episode_indices[starts] == episode_indices[starts + state_obs_steps - 1]
valid_starts = starts[valid]
if len(valid_starts) == 0:
raise RuntimeError("No valid state windows found within single episodes")
offsets = np.arange(state_obs_steps)
mask_dim = len(relative_mask)
running_stats = RunningQuantileStats()
batch_size = 50_000
for i in range(0, len(valid_starts), batch_size):
batch_starts = valid_starts[i : i + batch_size]
frame_idx = batch_starts[:, None] + offsets[None, :] # [N, state_obs_steps]
windows = all_states[frame_idx].copy() # [N, state_obs_steps, state_dim]
# Subtract current (last) timestep from all timesteps for masked dims
current = windows[:, -1:, :] # [N, 1, state_dim]
windows[:, :, :mask_dim] -= current[:, :, :mask_dim] * relative_mask[None, None, :]
# Flatten to [N, state_obs_steps * state_dim] (same as prepare_state)
flattened = windows.reshape(len(batch_starts), -1)
running_stats.update(flattened)
stats = running_stats.get_statistics()
excluded_dims = int(mask_dim - relative_mask.sum())
logging.info(
f"Relative state stats ({len(valid_starts)} windows, obs_steps={state_obs_steps}): "
f"relative_dims={int(relative_mask.sum())}/{mask_dim} (excluded={excluded_dims}), "
f"mean={np.abs(stats['mean']).mean():.4f}, std={stats['std'].mean():.4f}"
)
return stats

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from pathlib import Path
import numpy as np
@@ -43,16 +44,24 @@ from lerobot.datasets.utils import (
check_version_compatibility,
flatten_dict,
get_safe_version,
has_legacy_hub_download_metadata,
is_valid_version,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
CODEBASE_VERSION = "v3.0"
class LeRobotDatasetMetadata:
"""Metadata container for a LeRobot dataset.
Manages the ``info.json``, ``stats.json``, ``tasks.parquet``, and
``episodes/`` parquet files that describe a dataset's structure, content,
and statistics.
"""
def __init__(
self,
repo_id: str,
@@ -61,33 +70,57 @@ class LeRobotDatasetMetadata:
force_cache_sync: bool = False,
metadata_buffer_size: int = 10,
):
"""Load or download metadata for an existing LeRobot dataset.
Attempts to load metadata from local disk. If files are missing or
``force_cache_sync`` is ``True``, downloads the ``meta/`` directory from
the Hub.
Args:
repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``).
root: Local directory for the dataset. When provided, Hub downloads
are materialized directly into this directory. When omitted,
existing local datasets are still looked up under
``$HF_LEROBOT_HOME/{repo_id}``, but Hub downloads use a
revision-safe snapshot cache under
``$HF_LEROBOT_HOME/hub``.
revision: Git revision (branch, tag, or commit hash). Defaults to
the current codebase version.
force_cache_sync: If ``True``, re-download metadata from the Hub
even when local files exist.
metadata_buffer_size: Number of episode metadata records to buffer
in memory before flushing to parquet.
"""
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
self.writer = None
self._requested_root = Path(root) if root is not None else None
self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
self._pq_writer = None
self.latest_episode = None
self.metadata_buffer: list[dict] = []
self.metadata_buffer_size = metadata_buffer_size
self._metadata_buffer: list[dict] = []
self._metadata_buffer_size = metadata_buffer_size
self._finalized = False
try:
if force_cache_sync:
if force_cache_sync or (
self._requested_root is None and has_legacy_hub_download_metadata(self.root)
):
raise FileNotFoundError
self.load_metadata()
self._load_metadata()
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
self._pull_from_repo(allow_patterns="meta/")
self._load_metadata()
def _flush_metadata_buffer(self) -> None:
"""Write all buffered episode metadata to parquet file."""
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
if not hasattr(self, "_metadata_buffer") or len(self._metadata_buffer) == 0:
return
combined_dict = {}
for episode_dict in self.metadata_buffer:
for episode_dict in self._metadata_buffer:
for key, value in episode_dict.items():
if key not in combined_dict:
combined_dict[key] = []
@@ -96,40 +129,50 @@ class LeRobotDatasetMetadata:
val = value[0] if isinstance(value, list) else value
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
first_ep = self.metadata_buffer[0]
first_ep = self._metadata_buffer[0]
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
file_idx = first_ep["meta/episodes/file_index"][0]
table = pa.Table.from_pydict(combined_dict)
if not self.writer:
if not self._pq_writer:
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
path.parent.mkdir(parents=True, exist_ok=True)
self.writer = pq.ParquetWriter(
self._pq_writer = pq.ParquetWriter(
path, schema=table.schema, compression="snappy", use_dictionary=True
)
self.writer.write_table(table)
self._pq_writer.write_table(table)
self.latest_episode = self.metadata_buffer[-1]
self.metadata_buffer.clear()
self.latest_episode = self._metadata_buffer[-1]
self._metadata_buffer.clear()
def _close_writer(self) -> None:
"""Close and cleanup the parquet writer if it exists."""
self._flush_metadata_buffer()
writer = getattr(self, "writer", None)
writer = getattr(self, "_pq_writer", None)
if writer is not None:
writer.close()
self.writer = None
self._pq_writer = None
def finalize(self) -> None:
"""Flush metadata buffer and close the parquet writer.
Idempotent — safe to call multiple times.
"""
if getattr(self, "_finalized", False):
return
self._close_writer()
self._finalized = True
def __del__(self):
"""
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
"""
self._close_writer()
"""Safety net: flush and close parquet writer on garbage collection."""
# During interpreter shutdown, referenced objects may already be collected.
with contextlib.suppress(Exception):
self.finalize()
def load_metadata(self):
def _load_metadata(self):
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
@@ -137,22 +180,38 @@ class LeRobotDatasetMetadata:
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
def pull_from_repo(
def _pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
ignore_patterns: list[str] | str | None = None,
) -> None:
if self._requested_root is None:
self.root = Path(
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
cache_dir=HF_LEROBOT_HUB_CACHE,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
)
return
self._requested_root.mkdir(exist_ok=True, parents=True)
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
local_dir=self.root,
local_dir=self._requested_root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
self.root = self._requested_root
@property
def url_root(self) -> str:
"""Hugging Face Hub URL root for this dataset."""
return f"hf://datasets/{self.repo_id}"
@property
@@ -161,6 +220,17 @@ class LeRobotDatasetMetadata:
return packaging.version.parse(self.info["codebase_version"])
def get_data_file_path(self, ep_index: int) -> Path:
"""Return the relative parquet file path for the given episode index.
Args:
ep_index: Zero-based episode index.
Returns:
Path to the parquet file containing this episode's data.
Raises:
IndexError: If ``ep_index`` is out of range.
"""
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
@@ -174,6 +244,19 @@ class LeRobotDatasetMetadata:
return Path(fpath)
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
"""Return the relative video file path for the given episode and video key.
Args:
ep_index: Zero-based episode index.
vid_key: Feature key identifying the video stream
(e.g. ``'observation.images.laptop'``).
Returns:
Path to the video file containing this episode's frames.
Raises:
IndexError: If ``ep_index`` is out of range.
"""
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
@@ -277,6 +360,17 @@ class LeRobotDatasetMetadata:
return None
def save_episode_tasks(self, tasks: list[str]):
"""Register tasks for the current episode and persist to disk.
New tasks that do not already exist in the dataset are assigned
sequential task indices and appended to the tasks parquet file.
Args:
tasks: List of unique task descriptions in natural language.
Raises:
ValueError: If ``tasks`` contains duplicates.
"""
if len(set(tasks)) != len(tasks):
raise ValueError(f"Tasks are not unique: {tasks}")
@@ -336,8 +430,8 @@ class LeRobotDatasetMetadata:
latest_path = (
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
if self.writer is None
else self.writer.where
if self._pq_writer is None
else self._pq_writer.where
)
if Path(latest_path).exists():
@@ -359,10 +453,10 @@ class LeRobotDatasetMetadata:
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
# Add to buffer
self.metadata_buffer.append(episode_dict)
self._metadata_buffer.append(episode_dict)
self.latest_episode = episode_dict
if len(self.metadata_buffer) >= self.metadata_buffer_size:
if len(self._metadata_buffer) >= self._metadata_buffer_size:
self._flush_metadata_buffer()
def save_episode(
@@ -373,6 +467,20 @@ class LeRobotDatasetMetadata:
episode_stats: dict[str, dict],
episode_metadata: dict,
) -> None:
"""Persist episode metadata, update dataset info, and aggregate stats.
Writes the episode's metadata to the buffered parquet writer, increments
the total episode/frame counters in ``info.json``, and merges the
episode's statistics into the running dataset statistics.
Args:
episode_index: Zero-based index of the episode being saved.
episode_length: Number of frames in this episode.
episode_tasks: List of task descriptions for this episode.
episode_stats: Per-feature statistics for this episode.
episode_metadata: Additional metadata (chunk/file indices, frame
ranges, video timestamps, etc.).
"""
episode_dict = {
"episode_index": episode_index,
"tasks": episode_tasks,
@@ -479,10 +587,36 @@ class LeRobotDatasetMetadata:
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
"""Create metadata for a new LeRobot dataset from scratch.
Initializes the ``info.json`` file on disk with the provided feature
schema and dataset settings. No episode data is written yet.
Args:
repo_id: Repository identifier (e.g. ``'user/my_dataset'``).
fps: Frames per second used during data collection.
features: Feature specification dict mapping feature names to their
type/shape metadata.
robot_type: Optional robot type string stored in metadata.
root: Local directory for the dataset. Defaults to
``$HF_LEROBOT_HOME/{repo_id}``. Must not already exist.
use_videos: If ``True``, visual modalities are encoded as MP4 videos.
metadata_buffer_size: Number of episode metadata records to buffer
before flushing to parquet.
chunks_size: Max number of files per chunk directory. ``None`` uses
the default.
data_files_size_in_mb: Max parquet file size in MB. ``None`` uses the
default.
video_files_size_in_mb: Max video file size in MB. ``None`` uses the
default.
Returns:
A new :class:`LeRobotDatasetMetadata` instance.
"""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj._requested_root = Path(root) if root is not None else None
obj.root = obj._requested_root if obj._requested_root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
@@ -510,8 +644,9 @@ class LeRobotDatasetMetadata:
)
write_json(obj.info, obj.root / INFO_PATH)
obj.revision = None
obj.writer = None
obj._pq_writer = None
obj.latest_episode = None
obj.metadata_buffer = []
obj.metadata_buffer_size = metadata_buffer_size
obj._metadata_buffer = []
obj._metadata_buffer_size = metadata_buffer_size
obj._finalized = False
return obj

View File

@@ -0,0 +1,288 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Private reader component for LeRobotDataset. Handles random-access reading (HF dataset, delta indices, video decoding)."""
from collections.abc import Callable
from pathlib import Path
import datasets
import torch
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import (
check_delta_timestamps,
get_delta_indices,
get_hf_features_from_features,
)
from lerobot.datasets.io_utils import (
hf_transform_to_torch,
load_nested_dataset,
)
from lerobot.datasets.video_utils import decode_video_frames
class DatasetReader:
"""Encapsulates read-side state and methods for LeRobotDataset.
Owns: hf_dataset, _absolute_to_relative_idx, delta_indices.
"""
def __init__(
self,
meta: LeRobotDatasetMetadata,
root: Path,
episodes: list[int] | None,
tolerance_s: float,
video_backend: str,
delta_timestamps: dict[str, list[float]] | None,
image_transforms: Callable | None,
):
"""Initialize the reader with metadata, filtering, and transform config.
The HF dataset is not loaded here — call :meth:`try_load` or
:meth:`load_and_activate` afterward.
Args:
meta: Dataset metadata instance.
root: Local dataset root directory.
episodes: Optional list of episode indices to select. ``None``
means all episodes.
tolerance_s: Timestamp synchronization tolerance in seconds.
video_backend: Video decoding backend identifier.
delta_timestamps: Optional dict mapping feature keys to lists of
relative timestamp offsets for temporal context windows.
image_transforms: Optional torchvision v2 transform applied to
visual features.
"""
self._meta = meta
self.root = root
self.episodes = episodes
self._tolerance_s = tolerance_s
self._video_backend = video_backend
self._image_transforms = image_transforms
self.hf_dataset: datasets.Dataset | None = None
self._absolute_to_relative_idx: dict[int, int] | None = None
# Setup delta_indices (doesn't depend on hf_dataset)
self.delta_indices = None
if delta_timestamps is not None:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient."""
try:
self.hf_dataset = self._load_hf_dataset()
except (FileNotFoundError, NotADirectoryError):
self.hf_dataset = None
return False
if not self._check_cached_episodes_sufficient():
self.hf_dataset = None
return False
self._build_index_mapping()
return True
def load_and_activate(self) -> None:
"""Load HF dataset from disk and build index mapping. Call after data is on disk."""
self.hf_dataset = self._load_hf_dataset()
self._build_index_mapping()
def _build_index_mapping(self) -> None:
"""Build absolute-to-relative index mapping from loaded hf_dataset."""
self._absolute_to_relative_idx = None
if self.episodes is not None and self.hf_dataset is not None:
self._absolute_to_relative_idx = {
abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx
for rel_idx, abs_idx in enumerate(self.hf_dataset["index"])
}
@property
def num_frames(self) -> int:
"""Number of frames in selected episodes."""
if self.episodes is not None and self.hf_dataset is not None:
return len(self.hf_dataset)
return self._meta.total_frames
@property
def num_episodes(self) -> int:
"""Number of episodes selected."""
return len(self.episodes) if self.episodes is not None else self._meta.total_episodes
def _load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self._meta.features)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def _check_cached_episodes_sufficient(self) -> bool:
"""Check if the cached dataset contains all requested episodes and their video files."""
if self.hf_dataset is None or len(self.hf_dataset) == 0:
return False
available_episodes = {
ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
for ep_idx in self.hf_dataset.unique("episode_index")
}
if self.episodes is None:
requested_episodes = set(range(self._meta.total_episodes))
else:
requested_episodes = set(self.episodes)
if not requested_episodes.issubset(available_episodes):
return False
if len(self._meta.video_keys) > 0:
for ep_idx in requested_episodes:
for vid_key in self._meta.video_keys:
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
if not video_path.exists():
return False
return True
def get_episodes_file_paths(self) -> list[Path]:
"""Return deduplicated file paths (data + video) for selected episodes.
Used to build the ``allow_patterns`` list for ``snapshot_download``.
"""
episodes = self.episodes if self.episodes is not None else list(range(self._meta.total_episodes))
fpaths = [str(self._meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
if len(self._meta.video_keys) > 0:
video_files = [
str(self._meta.get_video_file_path(ep_idx, vid_key))
for vid_key in self._meta.video_keys
for ep_idx in episodes
]
fpaths += video_files
# episodes are stored in the same files, so we return unique paths only
fpaths = list(set(fpaths))
return fpaths
def _get_query_indices(
self, abs_idx: int, ep_idx: int
) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
"""Compute query indices for delta timestamps."""
ep = self._meta.episodes[ep_idx]
ep_start = ep["dataset_from_index"]
ep_end = ep["dataset_to_index"]
query_indices = {
key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
padding = {
f"{key}_is_pad": torch.BoolTensor(
[(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx]
)
for key, delta_idx in self.delta_indices.items()
}
return query_indices, padding
def _get_query_timestamps(
self,
current_ts: float,
query_indices: dict[str, list[int]] | None = None,
) -> dict[str, list[float]]:
query_timestamps = {}
for key in self._meta.video_keys:
if query_indices is not None and key in query_indices:
if self._absolute_to_relative_idx is not None:
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
timestamps = self.hf_dataset[relative_indices]["timestamp"]
else:
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
else:
query_timestamps[key] = [current_ts]
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
"""Query dataset for indices across keys, skipping video keys."""
result: dict = {}
for key, q_idx in query_indices.items():
if key in self._meta.video_keys:
continue
relative_indices = (
q_idx
if self._absolute_to_relative_idx is None
else [self._absolute_to_relative_idx[idx] for idx in q_idx]
)
try:
result[key] = torch.stack(self.hf_dataset[key][relative_indices])
except (KeyError, TypeError, IndexError):
result[key] = torch.stack(self.hf_dataset[relative_indices][key])
return result
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
Segmentation Fault.
"""
ep = self._meta.episodes[ep_idx]
item = {}
for vid_key, query_ts in query_timestamps.items():
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend)
item[vid_key] = frames.squeeze(0)
return item
def get_item(self, idx) -> dict:
"""Core __getitem__ logic. Assumes hf_dataset is loaded.
``idx`` is a *relative* index into the (possibly episode-filtered)
HF dataset, **not** the absolute frame index stored in the ``index``
column. The absolute index is retrieved from the row itself.
"""
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item()
abs_idx = item["index"].item()
query_indices = None
if self.delta_indices is not None:
query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
for key, val in query_result.items():
item[key] = val
if len(self._meta.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
if self._image_transforms is not None:
image_keys = self._meta.camera_keys
for cam in image_keys:
item[cam] = self._image_transforms(item[cam])
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self._meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
return item

View File

@@ -37,7 +37,12 @@ import torch
from tqdm import tqdm
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.compute_stats import (
aggregate_stats,
compute_episode_stats,
compute_relative_action_stats,
compute_relative_state_stats,
)
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.io_utils import (
get_parquet_file_size_in_mb,
@@ -56,7 +61,7 @@ from lerobot.datasets.utils import (
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
@@ -891,7 +896,7 @@ def _copy_and_reindex_episodes_metadata(
total_frames += src_episode["length"]
dst_meta._close_writer()
dst_meta.finalize()
dst_meta.info.update(
{
@@ -1533,6 +1538,147 @@ def modify_tasks(
return dataset
def recompute_stats(
dataset: LeRobotDataset,
skip_image_video: bool = True,
relative_action: bool = False,
relative_exclude_joints: list[str] | None = None,
chunk_size: int = 50,
num_workers: int = 0,
relative_state: bool = False,
relative_exclude_state_joints: list[str] | None = None,
state_obs_steps: int = 2,
derive_state_from_action: bool = False,
) -> LeRobotDataset:
"""Recompute stats.json from scratch by iterating all episodes.
Args:
dataset: The LeRobotDataset to recompute stats for.
skip_image_video: If True (default), only recompute stats for numeric features
(action, state, etc.) and keep existing image/video stats unchanged.
relative_action: If True, compute action stats in relative space by
iterating all valid action chunks and subtracting the current state.
This matches the normalization distribution the model sees during
training with ``use_relative_actions=True``.
relative_exclude_joints: Joint names to exclude from relative conversion when
relative_action=True. These dims keep absolute stats.
chunk_size: Action chunk size used for relative stats computation. Should match
``policy.chunk_size``. Only used when ``relative_action=True``.
num_workers: Number of parallel threads for relative action stats computation.
Values ≤1 mean single-threaded. Only used when ``relative_action=True``.
relative_state: If True, compute observation.state stats in relative space
(multi-timestep offsets from current). This matches the normalization
the model sees during training with ``use_relative_state=True``.
relative_exclude_state_joints: State dim names to exclude from relative conversion.
state_obs_steps: Number of observation timesteps for relative state stats.
Should match ``policy.state_obs_steps``. Only used when ``relative_state=True``.
derive_state_from_action: If True, compute relative state stats from the
action column instead of observation.state. Implies ``relative_state=True``
and ``state_obs_steps=2``.
Returns:
The same dataset with updated stats.
"""
if derive_state_from_action:
relative_state = True
state_obs_steps = 2
features = dataset.meta.features
meta_keys = {"index", "episode_index", "task_index", "frame_index", "timestamp"}
numeric_features = {
k: v
for k, v in features.items()
if v["dtype"] not in ["image", "video", "string"] and k not in meta_keys
}
if skip_image_video:
features_to_compute = numeric_features
else:
features_to_compute = {
k: v for k, v in features.items() if v["dtype"] != "string" and k not in meta_keys
}
# When relative_action is enabled, compute action stats via chunk-based sampling
# (matching what the model sees during training) and skip action in the
# per-episode pass below.
relative_action_stats = None
if relative_action and ACTION in features and OBS_STATE in features:
if relative_exclude_joints is None:
relative_exclude_joints = ["gripper"]
relative_action_stats = compute_relative_action_stats(
hf_dataset=dataset.hf_dataset,
features=features,
chunk_size=chunk_size,
exclude_joints=relative_exclude_joints,
num_workers=num_workers,
)
features_to_compute.pop(ACTION, None)
# When relative_state is enabled, compute state stats over the flattened
# multi-timestep relative representation (matching what the model sees).
relative_state_stats = None
if relative_state and (OBS_STATE in features or derive_state_from_action):
source_key = ACTION if derive_state_from_action else OBS_STATE
relative_state_stats = compute_relative_state_stats(
hf_dataset=dataset.hf_dataset,
features=features,
state_obs_steps=state_obs_steps,
exclude_joints=relative_exclude_state_joints,
source_key=source_key,
)
features_to_compute.pop(OBS_STATE, None)
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
data_dir = dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
all_episode_stats = []
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
df = pd.read_parquet(parquet_path)
for ep_idx in sorted(df["episode_index"].unique()):
ep_df = df[df["episode_index"] == ep_idx]
episode_data = {}
for key in numeric_keys:
if key in ep_df.columns:
values = ep_df[key].values
if hasattr(values[0], "__len__"):
episode_data[key] = np.stack(values)
else:
episode_data[key] = np.array(values)
ep_stats = compute_episode_stats(episode_data, features_to_compute)
all_episode_stats.append(ep_stats)
if features_to_compute and not all_episode_stats:
logging.warning("No episode stats computed")
return dataset
new_stats = aggregate_stats(all_episode_stats) if all_episode_stats else {}
if relative_action_stats is not None:
new_stats[ACTION] = relative_action_stats
if relative_state_stats is not None:
new_stats[OBS_STATE] = relative_state_stats
# Merge: keep existing stats for features we didn't recompute
if dataset.meta.stats:
for key, value in dataset.meta.stats.items():
if key not in new_stats:
new_stats[key] = value
write_stats(new_stats, dataset.root)
dataset.meta.stats = new_stats
logging.info("Stats recomputed successfully")
return dataset
def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path | None = None,

View File

@@ -0,0 +1,634 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Private writer component for LeRobotDataset. Handles sequential recording (episode buffer, ParquetWriter, image writer, video encoding)."""
from __future__ import annotations
import concurrent.futures
import contextlib
import logging
import shutil
import tempfile
from pathlib import Path
import datasets
import numpy as np
import pandas as pd
import PIL.Image
import pyarrow.parquet as pq
import torch
from lerobot.datasets.compute_stats import compute_episode_stats
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import (
get_hf_features_from_features,
validate_episode_buffer,
validate_frame,
)
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.datasets.io_utils import (
embed_images,
get_file_size_in_mb,
load_episodes,
write_info,
)
from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_IMAGE_PATH,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import (
StreamingVideoEncoder,
concatenate_video_files,
encode_video_frames,
get_video_duration_in_s,
)
logger = logging.getLogger(__name__)
def _encode_video_worker(
video_key: str,
episode_index: int,
root: Path,
fps: int,
vcodec: str = "libsvtav1",
encoder_threads: int | None = None,
) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent
encode_video_frames(
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
)
shutil.rmtree(img_dir)
return temp_path
class DatasetWriter:
"""Encapsulates write-side state and methods for LeRobotDataset.
Owns: episode_buffer, image_writer, _pq_writer (ParquetWriter), _latest_episode,
_current_file_start_frame, _streaming_encoder, _episodes_since_last_encoding, _recorded_frames.
"""
def __init__(
self,
meta: LeRobotDatasetMetadata,
root: Path,
vcodec: str,
encoder_threads: int | None,
batch_encoding_size: int,
streaming_encoder: StreamingVideoEncoder | None = None,
initial_frames: int = 0,
):
"""Initialize the writer with metadata, codec, and encoding config.
Args:
meta: Dataset metadata instance (used for feature schema, chunk
settings, and episode persistence).
root: Local dataset root directory.
vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``).
encoder_threads: Threads per encoder instance. ``None`` for auto.
batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos.
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
for real-time encoding. ``None`` disables streaming mode.
initial_frames: Starting frame count (non-zero when resuming).
"""
self._meta = meta
self._root = root
self._vcodec = vcodec
self._encoder_threads = encoder_threads
self._batch_encoding_size = batch_encoding_size
self._streaming_encoder = streaming_encoder
# Writer state
self.image_writer: AsyncImageWriter | None = None
self.episode_buffer: dict = self._create_episode_buffer()
self._pq_writer: pq.ParquetWriter | None = None
self._latest_episode: dict | None = None
self._current_file_start_frame: int | None = None
self._episodes_since_last_encoding: int = 0
self._recorded_frames: int = initial_frames
self._finalized = False
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self._meta.total_episodes if episode_index is None else episode_index
ep_buffer = {}
ep_buffer["size"] = 0
ep_buffer["task"] = []
for key in self._meta.features:
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
return ep_buffer
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return self._root / fpath
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
def _save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
) -> None:
if self.image_writer is None:
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
write_image(image, fpath, compress_level=compress_level)
else:
self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level)
def add_frame(self, frame: dict) -> None:
"""
Add a single frame to the current episode buffer.
Apart from images written to a temporary directory, nothing is written to disk
until ``save_episode()`` is called.
The caller must provide all user-defined features plus ``"task"``, and must
not provide ``"timestamp"`` or ``"frame_index"``; those are computed
automatically.
"""
# Convert torch to numpy if needed
for name in frame:
if isinstance(frame[name], torch.Tensor):
frame[name] = frame[name].numpy()
validate_frame(frame, self._meta.features)
if self.episode_buffer is None:
self.episode_buffer = self._create_episode_buffer()
# Automatically add frame_index and timestamp to episode buffer
frame_index = self.episode_buffer["size"]
timestamp = frame_index / self._meta.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
self.episode_buffer["task"].append(frame.pop("task"))
# Start streaming encoder on first frame of episode
if frame_index == 0 and self._streaming_encoder is not None:
self._streaming_encoder.start_episode(
video_keys=list(self._meta.video_keys),
temp_dir=self._root,
)
# Add frame features to episode_buffer
for key in frame:
if key not in self._meta.features:
raise ValueError(
f"An element of the frame is not in the features. '{key}' not in '{self._meta.features.keys()}'."
)
if self._meta.features[key]["dtype"] == "video" and self._streaming_encoder is not None:
self._streaming_encoder.feed_frame(key, frame[key])
self.episode_buffer[key].append(None)
elif self._meta.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
)
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
compress_level = 1 if self._meta.features[key]["dtype"] == "video" else 6
self._save_image(frame[key], img_path, compress_level)
self.episode_buffer[key].append(str(img_path))
else:
self.episode_buffer[key].append(frame[key])
self.episode_buffer["size"] += 1
def save_episode(
self,
episode_data: dict | None = None,
parallel_encoding: bool = True,
) -> None:
"""Save the current episode in self.episode_buffer to disk."""
episode_buffer = episode_data if episode_data is not None else self.episode_buffer
validate_episode_buffer(episode_buffer, self._meta.total_episodes, self._meta.features)
# size and task are special cases that won't be added to hf_dataset
episode_length = episode_buffer.pop("size")
tasks = episode_buffer.pop("task")
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
episode_buffer["index"] = np.arange(self._meta.total_frames, self._meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
# Update tasks and task indices with new tasks if any
self._meta.save_episode_tasks(episode_tasks)
# Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array([self._meta.get_task_index(task) for task in tasks])
for key, ft in self._meta.features.items():
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
episode_buffer[key] = np.stack(episode_buffer[key])
# Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer()
has_video_keys = len(self._meta.video_keys) > 0
use_streaming = self._streaming_encoder is not None and has_video_keys
use_batched_encoding = self._batch_encoding_size > 1
if use_streaming:
non_video_buffer = {
k: v
for k, v in episode_buffer.items()
if self._meta.features.get(k, {}).get("dtype") not in ("video",)
}
non_video_features = {k: v for k, v in self._meta.features.items() if v["dtype"] != "video"}
ep_stats = compute_episode_stats(non_video_buffer, non_video_features)
else:
ep_stats = compute_episode_stats(episode_buffer, self._meta.features)
ep_metadata = self._save_episode_data(episode_buffer)
if use_streaming:
streaming_results = self._streaming_encoder.finish_episode()
for video_key in self._meta.video_keys:
temp_path, video_stats = streaming_results[video_key]
if video_stats is not None:
ep_stats[video_key] = {
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
for k, v in video_stats.items()
}
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
elif has_video_keys and not use_batched_encoding:
num_cameras = len(self._meta.video_keys)
if parallel_encoding and num_cameras > 1:
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
future_to_key = {
executor.submit(
_encode_video_worker,
video_key,
episode_index,
self._root,
self._meta.fps,
self._vcodec,
self._encoder_threads,
): video_key
for video_key in self._meta.video_keys
}
results = {}
for future in concurrent.futures.as_completed(future_to_key):
video_key = future_to_key[future]
try:
temp_path = future.result()
results[video_key] = temp_path
except Exception as exc:
logger.error(f"Video encoding failed for {video_key}: {exc}")
raise exc
for video_key in self._meta.video_keys:
temp_path = results[video_key]
ep_metadata.update(
self._save_episode_video(video_key, episode_index, temp_path=temp_path)
)
else:
for video_key in self._meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index))
# `meta.save_episode` need to be executed after encoding the videos
self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
if has_video_keys and use_batched_encoding:
self._episodes_since_last_encoding += 1
if self._episodes_since_last_encoding == self._batch_encoding_size:
start_ep = self._meta.total_episodes - self._batch_encoding_size
end_ep = self._meta.total_episodes
self._batch_save_episode_video(start_ep, end_ep)
self._episodes_since_last_encoding = 0
if episode_data is None:
self.clear_episode_buffer(delete_images=len(self._meta.image_keys) > 0)
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
"""Batch save videos for multiple episodes."""
if end_episode is None:
end_episode = self._meta.total_episodes
logger.info(
f"Batch encoding {self._batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}"
)
chunk_idx = self._meta.episodes[start_episode]["data/chunk_index"]
file_idx = self._meta.episodes[start_episode]["data/file_index"]
episode_df_path = self._root / DEFAULT_EPISODES_PATH.format(
chunk_index=chunk_idx, file_index=file_idx
)
episode_df = pd.read_parquet(episode_df_path)
for ep_idx in range(start_episode, end_episode):
logger.info(f"Encoding videos for episode {ep_idx}")
if (
self._meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
or self._meta.episodes[ep_idx]["data/file_index"] != file_idx
):
episode_df.to_parquet(episode_df_path)
self._meta.episodes = load_episodes(self._root)
chunk_idx = self._meta.episodes[ep_idx]["data/chunk_index"]
file_idx = self._meta.episodes[ep_idx]["data/file_index"]
episode_df_path = self._root / DEFAULT_EPISODES_PATH.format(
chunk_index=chunk_idx, file_index=file_idx
)
episode_df = pd.read_parquet(episode_df_path)
video_ep_metadata = {}
for video_key in self._meta.video_keys:
video_ep_metadata.update(self._save_episode_video(video_key, ep_idx))
video_ep_metadata.pop("episode_index")
video_ep_df = pd.DataFrame(video_ep_metadata, index=[ep_idx]).convert_dtypes(
dtype_backend="pyarrow"
)
episode_df = episode_df.combine_first(video_ep_df)
episode_df.to_parquet(episode_df_path)
self._meta.episodes = load_episodes(self._root)
def _save_episode_data(self, episode_buffer: dict) -> dict:
"""Save episode data to a parquet file."""
# Use metadata features as the authoritative schema
hf_features = get_hf_features_from_features(self._meta.features)
ep_dict = {key: episode_buffer[key] for key in hf_features}
ep_dataset = datasets.Dataset.from_dict(ep_dict, features=hf_features, split="train")
ep_dataset = embed_images(ep_dataset)
ep_num_frames = len(ep_dataset)
if self._latest_episode is None:
chunk_idx, file_idx = 0, 0
global_frame_index = 0
self._current_file_start_frame = 0
if self._meta.episodes is not None and len(self._meta.episodes) > 0:
latest_ep = self._meta.episodes[-1]
global_frame_index = latest_ep["dataset_to_index"]
chunk_idx = latest_ep["data/chunk_index"]
file_idx = latest_ep["data/file_index"]
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size)
self._current_file_start_frame = global_frame_index
else:
latest_ep = self._latest_episode
chunk_idx = latest_ep["data/chunk_index"]
file_idx = latest_ep["data/file_index"]
global_frame_index = latest_ep["index"][-1] + 1
latest_path = self._root / self._meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
latest_size_in_mb = get_file_size_in_mb(latest_path)
frames_in_current_file = global_frame_index - self._current_file_start_frame
av_size_per_frame = (
latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
)
if latest_size_in_mb + av_size_per_frame * ep_num_frames >= self._meta.data_files_size_in_mb:
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size)
self.close_writer()
self._current_file_start_frame = global_frame_index
ep_dict["data/chunk_index"] = chunk_idx
ep_dict["data/file_index"] = file_idx
path = self._root / self._meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
table = ep_dataset.with_format("arrow")[:]
if not self._pq_writer:
self._pq_writer = pq.ParquetWriter(
path, schema=table.schema, compression="snappy", use_dictionary=True
)
self._pq_writer.write_table(table)
metadata = {
"data/chunk_index": chunk_idx,
"data/file_index": file_idx,
"dataset_from_index": global_frame_index,
"dataset_to_index": global_frame_index + ep_num_frames,
}
self._latest_episode = {**ep_dict, **metadata}
self._recorded_frames += ep_num_frames
return metadata
def _save_episode_video(
self,
video_key: str,
episode_index: int,
temp_path: Path | None = None,
) -> dict:
if temp_path is None:
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
else:
ep_path = temp_path
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
if (
episode_index == 0
or self._meta.latest_episode is None
or f"videos/{video_key}/chunk_index" not in self._meta.latest_episode
):
chunk_idx, file_idx = 0, 0
if self._meta.episodes is not None and len(self._meta.episodes) > 0:
old_chunk_idx = self._meta.episodes[-1][f"videos/{video_key}/chunk_index"]
old_file_idx = self._meta.episodes[-1][f"videos/{video_key}/file_index"]
chunk_idx, file_idx = update_chunk_file_indices(
old_chunk_idx, old_file_idx, self._meta.chunks_size
)
latest_duration_in_s = 0.0
new_path = self._root / self._meta.video_path.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
new_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(ep_path), str(new_path))
else:
latest_ep = self._meta.latest_episode
chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0]
file_idx = latest_ep[f"videos/{video_key}/file_index"][0]
latest_path = self._root / self._meta.video_path.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
latest_size_in_mb = get_file_size_in_mb(latest_path)
latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0]
if latest_size_in_mb + ep_size_in_mb >= self._meta.video_files_size_in_mb:
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size)
new_path = self._root / self._meta.video_path.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
new_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(ep_path), str(new_path))
latest_duration_in_s = 0.0
else:
concatenate_video_files(
[latest_path, ep_path],
latest_path,
)
# Remove temporary directory
shutil.rmtree(str(ep_path.parent))
# Update video info (only needed when first episode is encoded)
if episode_index == 0:
self._meta.update_video_info(video_key)
write_info(self._meta.info, self._meta.root)
metadata = {
"episode_index": episode_index,
f"videos/{video_key}/chunk_index": chunk_idx,
f"videos/{video_key}/file_index": file_idx,
f"videos/{video_key}/from_timestamp": latest_duration_in_s,
f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
}
return metadata
def clear_episode_buffer(self, delete_images: bool = True) -> None:
"""Discard the current episode buffer and optionally delete temp images.
Args:
delete_images: If ``True``, remove temporary image directories
written for the current episode.
"""
# Cancel streaming encoder if active
if self._streaming_encoder is not None:
self._streaming_encoder.cancel_episode()
if delete_images:
if self.image_writer is not None:
self._wait_image_writer()
episode_index = self.episode_buffer["episode_index"]
# episode_index is `int` when freshly created, but becomes `np.ndarray` after
# save_episode() mutates the buffer. Handle both types here.
if isinstance(episode_index, np.ndarray):
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
for cam_key in self._meta.image_keys:
img_dir = self._get_image_file_dir(episode_index, cam_key)
if img_dir.is_dir():
shutil.rmtree(img_dir)
self.episode_buffer = self._create_episode_buffer()
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
"""Start an :class:`AsyncImageWriter` for background image persistence.
Args:
num_processes: Number of subprocesses. ``0`` means threads only.
num_threads: Number of threads per process.
"""
if isinstance(self.image_writer, AsyncImageWriter):
logger.warning(
"You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
)
self.image_writer = AsyncImageWriter(
num_processes=num_processes,
num_threads=num_threads,
)
def stop_image_writer(self) -> None:
"""Stop the image writer (needed before pickling the dataset for DataLoader)."""
if self.image_writer is not None:
self.image_writer.stop()
self.image_writer = None
def _wait_image_writer(self) -> None:
"""Wait for asynchronous image writer to finish."""
if self.image_writer is not None:
self.image_writer.wait_until_done()
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
return _encode_video_worker(
video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads
)
def close_writer(self) -> None:
"""Close and cleanup the parquet writer if it exists."""
if self._pq_writer is not None:
self._pq_writer.close()
self._pq_writer = None
def flush_pending_videos(self) -> None:
"""Flush any pending video encoding (streaming or batch).
For streaming encoding: closes the encoder.
For batch encoding: encodes any remaining episodes that haven't been batch-encoded yet.
"""
if self._streaming_encoder is not None:
self._streaming_encoder.close()
elif self._episodes_since_last_encoding > 0:
start_ep = self._meta.total_episodes - self._episodes_since_last_encoding
end_ep = self._meta.total_episodes
logger.info(
f"Encoding remaining {self._episodes_since_last_encoding} episodes, "
f"from episode {start_ep} to {end_ep - 1}"
)
self._batch_save_episode_video(start_ep, end_ep)
def cancel_pending_videos(self) -> None:
"""Cancel any in-progress streaming encoding without flushing."""
if self._streaming_encoder is not None:
self._streaming_encoder.cancel_episode()
def cleanup_interrupted_episode(self, episode_index: int) -> None:
"""Remove temporary image directories for an interrupted episode."""
for key in self._meta.video_keys:
img_dir = self._get_image_file_path(
episode_index=episode_index, image_key=key, frame_index=0
).parent
if img_dir.exists():
logger.debug(
f"Cleaning up interrupted episode images for episode {episode_index}, camera {key}"
)
shutil.rmtree(img_dir)
def finalize(self) -> None:
"""Flush all pending work and release all resources.
Idempotent — safe to call multiple times.
"""
if getattr(self, "_finalized", False):
return
# 1. Wait for async image writes to complete, then stop
if self.image_writer is not None:
self.image_writer.wait_until_done()
self.image_writer.stop()
self.image_writer = None
# 2. Flush pending video encoding (streaming or batch)
self.flush_pending_videos()
# 3. Close own parquet writer
self.close_writer()
# 4. Finalize metadata (idempotent)
self._meta.finalize()
self._finalized = True
def __del__(self):
"""Safety net: release resources on garbage collection."""
# During interpreter shutdown, referenced objects may already be collected.
with contextlib.suppress(Exception):
self.finalize()

View File

@@ -25,7 +25,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
from lerobot.utils.constants import ACTION, OBS_PREFIX, OBS_STATE, REWARD
IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
@@ -52,12 +52,15 @@ def resolve_delta_timestamps(
returns `None` if the resulting dict is empty.
"""
delta_timestamps = {}
state_delta = getattr(cfg, "state_delta_indices", None)
for key in ds_meta.features:
if key == REWARD and cfg.reward_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == ACTION and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
if key == OBS_STATE and state_delta is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in state_delta]
elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
if len(delta_timestamps) == 0:

View File

@@ -365,6 +365,10 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
def validate_frame(frame: dict, features: dict) -> None:
# DEFAULT_FEATURES (timestamp, frame_index, episode_index, index, task_index) are
# auto-populated by the recording pipeline (add_frame / save_episode) and must not
# be supplied by the caller. Excluding them here means any frame dict that contains
# these keys will be rejected as extra features.
expected_features = set(features) - set(DEFAULT_FEATURES)
actual_features = set(frame)

View File

@@ -32,10 +32,10 @@ def safe_stop_image_writer(func):
return func(*args, **kwargs)
except Exception as e:
dataset = kwargs.get("dataset")
image_writer = getattr(dataset, "image_writer", None) if dataset else None
if image_writer is not None:
writer = getattr(dataset, "writer", None) if dataset else None
if writer is not None and writer.image_writer is not None:
logger.warning("Waiting for image writer to terminate...")
image_writer.stop()
writer.image_writer.stop()
raise e
return wrapper

File diff suppressed because it is too large Load Diff

View File

@@ -22,6 +22,7 @@ import torch
import torch.utils
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import VideoFrame
from lerobot.utils.constants import HF_LEROBOT_HOME
@@ -125,7 +126,13 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
features.update(
{
k: v
for k, v in get_hf_features_from_features(dataset.features).items()
if k not in self.disabled_features
}
)
return features
@property

View File

@@ -255,7 +255,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
Args:
repo_id (str): This is the repo id that will be used to fetch the dataset.
root (Path | None, optional): Local directory to use for downloading/writing files.
root (Path | None, optional): Local directory to use for local datasets. When omitted, Hub
metadata is resolved through a revision-safe snapshot cache under
``$HF_LEROBOT_HOME/hub``.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list.
image_transforms (Callable | None, optional): Transform to apply to image data.
@@ -271,7 +273,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
"""
super().__init__()
self.repo_id = repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self._requested_root = Path(root) if root else None
self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
self.streaming_from_local = root is not None
self.image_transforms = image_transforms
@@ -288,12 +291,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
self.root.mkdir(exist_ok=True, parents=True)
if self._requested_root is not None:
self.root.mkdir(exist_ok=True, parents=True)
# Load metadata
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync
)
self.root = self.meta.root
self.revision = self.meta.revision
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)

View File

@@ -18,6 +18,7 @@ import importlib.resources
import json
import logging
from collections.abc import Iterator
from pathlib import Path
from typing import Any
import datasets
@@ -101,6 +102,18 @@ DEFAULT_FEATURES = {
}
def has_legacy_hub_download_metadata(root: Path) -> bool:
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
``snapshot_download(local_dir=...)`` stores lightweight metadata under
``<local_dir>/.cache/huggingface/download/``. The presence of this
directory is a reliable indicator that the dataset was downloaded with
the old non-revision-safe ``local_dir`` mode and should be re-fetched
through the snapshot cache instead.
"""
return (root / ".cache" / "huggingface" / "download").exists()
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
if file_idx == chunks_size - 1:
file_idx = 0

View File

@@ -741,6 +741,7 @@ class StreamingVideoEncoder:
self._video_paths: dict[str, Path] = {}
self._dropped_frames: dict[str, int] = {}
self._episode_active = False
self._closed = False
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
"""Start encoder threads for a new episode.
@@ -895,8 +896,11 @@ class StreamingVideoEncoder:
def close(self) -> None:
"""Close the encoder, canceling any in-progress episode."""
if self._closed:
return
if self._episode_active:
self.cancel_episode()
self._closed = True
def _cleanup(self) -> None:
"""Clean up queues and thread tracking dicts."""
@@ -1063,43 +1067,19 @@ class VideoEncodingManager:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
streaming_encoder = getattr(self.dataset, "_streaming_encoder", None)
writer = self.dataset.writer
if writer is not None:
if exc_type is not None and writer._streaming_encoder is not None:
writer.cancel_pending_videos()
if streaming_encoder is not None:
# Handle streaming encoder cleanup
if exc_type is not None:
streaming_encoder.cancel_episode()
streaming_encoder.close()
elif self.dataset.episodes_since_last_encoding > 0:
# Handle any remaining episodes that haven't been batch encoded
if exc_type is not None:
logger.info("Exception occurred. Encoding remaining episodes before exit...")
else:
logger.info("Recording stopped. Encoding remaining episodes...")
# finalize() handles flush_pending_videos + parquet + metadata
self.dataset.finalize()
start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
end_ep = self.dataset.num_episodes
logger.info(
f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
f"from episode {start_ep} to {end_ep - 1}"
)
self.dataset._batch_save_episode_video(start_ep, end_ep)
# Finalize the dataset to properly close all writers
self.dataset.finalize()
# Clean up episode images if recording was interrupted (only for non-streaming mode)
if exc_type is not None and streaming_encoder is None:
interrupted_episode_index = self.dataset.num_episodes
for key in self.dataset.meta.video_keys:
img_dir = self.dataset._get_image_file_path(
episode_index=interrupted_episode_index, image_key=key, frame_index=0
).parent
if img_dir.exists():
logger.debug(
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
)
shutil.rmtree(img_dir)
# Clean up episode images if recording was interrupted (only for non-streaming mode)
if exc_type is not None and writer._streaming_encoder is None:
writer.cleanup_interrupted_episode(self.dataset.num_episodes)
else:
self.dataset.finalize()
# Clean up any remaining images directory if it's empty
img_dir = self.dataset.root / "images"

View File

@@ -15,6 +15,7 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
@@ -28,6 +29,7 @@ from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
__all__ = [
"ACTConfig",
"DiffusionConfig",
"MultiTaskDiTConfig",
"PI0Config",
"PI05Config",
"PI0FastConfig",

View File

@@ -31,6 +31,7 @@ from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -58,6 +59,29 @@ from lerobot.utils.constants import (
)
def _reconnect_relative_absolute_steps(
preprocessor: PolicyProcessorPipeline, postprocessor: PolicyProcessorPipeline
) -> None:
"""Wire AbsoluteActionsProcessorStep.relative_step to the RelativeActionsProcessorStep after deserialization.
After a policy is loaded from disk, the preprocessor and postprocessor are reconstructed
independently from their configs. AbsoluteActionsProcessorStep needs a live reference to
the RelativeActionsProcessorStep so it can read the cached state at inference time.
That reference is not serializable, so we re-establish it here after loading.
"""
from lerobot.processor.relative_action_processor import (
AbsoluteActionsProcessorStep,
RelativeActionsProcessorStep,
)
relative_step = next((s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep)), None)
if relative_step is None:
return
for step in postprocessor.steps:
if isinstance(step, AbsoluteActionsProcessorStep) and step.relative_step is None:
step.relative_step = relative_step
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
"""
Retrieves a policy class by its registered name.
@@ -67,8 +91,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -87,6 +110,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.act.modeling_act import ACTPolicy
return ACTPolicy
elif name == "multi_task_dit":
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
return MultiTaskDiTPolicy
elif name == "vqbet":
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
@@ -147,8 +174,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier", "wall_x".
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"smolvla", "reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -163,6 +190,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return DiffusionConfig(**kwargs)
elif policy_type == "act":
return ACTConfig(**kwargs)
elif policy_type == "multi_task_dit":
return MultiTaskDiTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
@@ -263,26 +292,26 @@ def make_pre_post_processors(
kwargs["preprocessor_overrides"] = preprocessor_overrides
kwargs["postprocessor_overrides"] = postprocessor_overrides
return (
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get(
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("preprocessor_overrides", {}),
to_transition=batch_to_transition,
to_output=transition_to_batch,
),
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get(
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("postprocessor_overrides", {}),
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
preprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get(
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("preprocessor_overrides", {}),
to_transition=batch_to_transition,
to_output=transition_to_batch,
)
postprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get(
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("postprocessor_overrides", {}),
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
)
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
# Create a new processor based on policy type
if isinstance(policy_cfg, TDMPCConfig):
@@ -309,6 +338,16 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, MultiTaskDiTConfig):
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
make_multi_task_dit_pre_post_processors,
)
processors = make_multi_task_dit_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, VQBeTConfig):
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
@@ -470,6 +509,13 @@ def make_policy(
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
# Store action feature names for relative_exclude_joints support
if ds_meta is not None and hasattr(cfg, "action_feature_names"):
action_names = ds_meta.features.get(ACTION, {}).get("names")
if action_names is not None:
cfg.action_feature_names = list(action_names)
kwargs["config"] = cfg
# Pass dataset_stats to the policy if available (needed for some policies like SARM)

View File

@@ -0,0 +1,37 @@
# Multitask DiT Policy
## Citation
If you use this work, please cite the following works:
```bibtex
@misc{jones2025multitaskditpolicy,
author = {Bryson Jones},
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
year = {2025},
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
note = {Blog post}
}
```
```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
author = {TRI LBM Team},
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
year = {2025},
eprint = {arXiv:2507.05331},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2507.05331}
}
```
```bibtex
@misc{bostondynamics2025largebehaviormodelsatlas,
author = {Boston Dynamics and TRI Research Team},
title = {Large Behavior Models and Atlas Find New Footing},
year = {2025},
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
note = {Blog post}
}
```

View File

@@ -0,0 +1,21 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_multi_task_dit import MultiTaskDiTConfig
from .modeling_multi_task_dit import MultiTaskDiTPolicy
from .processor_multi_task_dit import make_multi_task_dit_pre_post_processors
__all__ = ["MultiTaskDiTConfig", "MultiTaskDiTPolicy", "make_multi_task_dit_pre_post_processors"]

View File

@@ -0,0 +1,256 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig
@PreTrainedConfig.register_subclass("multi_task_dit")
@dataclass
class MultiTaskDiTConfig(PreTrainedConfig):
"""Configuration for the Multi-Task Diffusion Transformer (DiT) policy.
A transformer-based policy that supports both diffusion and flow matching objectives
for multi-task robot learning with text and vision conditioning.
"""
n_obs_steps: int = 2 # Number of observation steps for temporal context
horizon: int = 32 # Number of action steps to predict
n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz)
# Objective Selection
objective: str = "diffusion" # "diffusion" or "flow_matching"
# --- Diffusion-specific (used when objective="diffusion") ---
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
num_train_timesteps: int = 100 # Number of diffusion timesteps
beta_schedule: str = "squaredcos_cap_v2" # Noise schedule type
beta_start: float = 0.0001 # Starting noise level
beta_end: float = 0.02 # Ending noise level
prediction_type: str = "epsilon" # "epsilon" (predict noise) or "sample" (predict clean)
clip_sample: bool = True # Clip samples during denoising
clip_sample_range: float = 1.0 # Clipping range [-x, x]
num_inference_steps: int | None = None # Denoising steps at inference (defaults to num_train_timesteps)
# --- Flow Matching-specific (used when objective="flow_matching") ---
sigma_min: float = 0.0 # Minimum noise in flow interpolation path
num_integration_steps: int = 100 # ODE integration steps at inference
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
timestep_sampling_strategy: str = "beta" # "uniform" or "beta"
timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold
timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha
timestep_sampling_beta: float = 1.0 # (beta only) Beta distribution beta
# Transformer Architecture
hidden_dim: int = 512 # Transformer hidden dimension
num_layers: int = 6 # Number of transformer layers
num_heads: int = 8 # Number of attention heads
dropout: float = 0.1 # Dropout rate
use_positional_encoding: bool = False # Use absolute positional encoding
timestep_embed_dim: int = 256 # Timestep embedding dimension
use_rope: bool = True # Use Rotary Position Embedding
rope_base: float = 10000.0 # RoPE base frequency
# Vision Encoder (CLIP)
vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
use_separate_rgb_encoder_per_camera: bool = False # Separate encoder per camera view
vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder
image_resize_shape: tuple[int, int] | None = None # Resize images before crop
image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default)
image_crop_is_random: bool = True # Random crop during training, center at inference
# Text Encoder (CLIP)
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
tokenizer_max_length: int = 77 # Max length for tokenized text (CLIP default is 77)
tokenizer_padding: str = "max_length" # Padding strategy: "max_length" or "longest"
tokenizer_padding_side: str = "right" # Padding side: "left" or "right"
tokenizer_truncation: bool = True # Whether to truncate sequences longer than max_length
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# Training/Optimizer
optimizer_lr: float = 2e-5
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.0
scheduler_name: str = "cosine"
scheduler_warmup_steps: int = 0
do_mask_loss_for_padding: bool = False
# Auto-calculated
drop_n_last_frames: int | None = None
def __post_init__(self):
super().__post_init__()
if self.drop_n_last_frames is None:
self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1
self._validate()
def _validate(self):
"""Validate configuration parameters."""
# Objective validation
if self.objective not in ["diffusion", "flow_matching"]:
raise ValueError(f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'")
# Transformer validation
if self.hidden_dim <= 0:
raise ValueError("hidden_dim must be positive")
if self.num_layers <= 0:
raise ValueError("num_layers must be positive")
if self.num_heads <= 0:
raise ValueError("num_heads must be positive")
if self.hidden_dim % self.num_heads != 0:
raise ValueError("hidden_dim must be divisible by num_heads")
if not (0.0 <= self.dropout <= 1.0):
raise ValueError("dropout must be between 0.0 and 1.0")
# Vision encoder validation
if "clip" not in self.vision_encoder_name.lower():
raise ValueError(
f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'"
)
if (
self.image_resize_shape
and self.image_crop_shape
and (
self.image_crop_shape[0] > self.image_resize_shape[0]
or self.image_crop_shape[1] > self.image_resize_shape[1]
)
):
logging.warning(
"image_crop_shape %s must be <= image_resize_shape %s; disabling cropping.",
self.image_crop_shape,
self.image_resize_shape,
)
self.image_crop_shape = None
# Text encoder validation
if "clip" not in self.text_encoder_name.lower():
raise ValueError(
f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'"
)
# Objective-specific validation
if self.objective == "diffusion":
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
raise ValueError(
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
)
if self.prediction_type not in ["epsilon", "sample"]:
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
if self.num_train_timesteps <= 0:
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}")
elif self.objective == "flow_matching":
if not (0.0 <= self.sigma_min <= 1.0):
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
if self.num_integration_steps <= 0:
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
if self.integration_method not in ["euler", "rk4"]:
raise ValueError(
f"integration_method must be 'euler' or 'rk4', got {self.integration_method}"
)
if self.timestep_sampling_strategy not in ["uniform", "beta"]:
raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'")
if self.timestep_sampling_strategy == "beta":
if not (0.0 < self.timestep_sampling_s <= 1.0):
raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}")
if self.timestep_sampling_alpha <= 0:
raise ValueError("timestep_sampling_alpha must be positive")
if self.timestep_sampling_beta <= 0:
raise ValueError("timestep_sampling_beta must be positive")
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
return DiffuserSchedulerConfig(
name=self.scheduler_name,
num_warmup_steps=self.scheduler_warmup_steps,
)
def validate_features(self) -> None:
"""Validate that required input features are present and properly configured."""
# If the configured crop doesn't fit, disable cropping instead of erroring.
# Note: if image_resize_shape is set, cropping is applied *after* resizing.
if self.image_crop_shape is not None:
for key, image_ft in self.image_features.items():
# image_ft.shape is (C, H, W)
effective_h, effective_w = (
self.image_resize_shape
if self.image_resize_shape is not None
else (image_ft.shape[1], image_ft.shape[2])
)
if self.image_crop_shape[0] > effective_h or self.image_crop_shape[1] > effective_w:
logging.warning(
"image_crop_shape %s doesn't fit within effective image shape (%s, %s) for '%s'; disabling cropping.",
self.image_crop_shape,
effective_h,
effective_w,
key,
)
self.image_crop_shape = None
break
if len(self.image_features) > 0:
first_key, first_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_ft.shape:
raise ValueError(
f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}"
)
@property
def is_diffusion(self) -> bool:
return self.objective == "diffusion"
@property
def is_flow_matching(self) -> bool:
return self.objective == "flow_matching"
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,803 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Task Diffusion Transformer (DiT) Policy
Transformer-based diffusion policy for multi-task robot learning with text and vision conditioning.
Supports both diffusion and flow matching objectives for action generation.
References:
- https://arxiv.org/abs/2507.05331
- https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/
- https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy
"""
import math
from collections import deque
from typing import TYPE_CHECKING
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import CLIPTextModel, CLIPVisionModel
else:
CLIPTextModel = None
CLIPVisionModel = None
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
# -- Policy --
class MultiTaskDiTPolicy(PreTrainedPolicy):
config_class = MultiTaskDiTConfig
name = "multi_task_dit"
def __init__(self, config: MultiTaskDiTConfig, **kwargs):
super().__init__(config)
config.validate_features()
self.config = config
self._queues = None
self.observation_encoder = ObservationEncoder(config)
conditioning_dim = self.observation_encoder.conditioning_dim
self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim)
action_dim = config.action_feature.shape[0]
horizon = config.horizon
if config.is_diffusion:
self.objective = DiffusionObjective(
config,
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
elif config.is_flow_matching:
self.objective = FlowMatchingObjective(
config,
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
else:
raise ValueError(f"Unsupported objective: {config.objective}")
self.reset()
def get_optim_params(self) -> list:
"""Returns parameter groups with different learning rates for vision vs non-vision parameters"""
non_vision_params = []
vision_encoder_params = []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
if "observation_encoder.vision_encoder" in name:
vision_encoder_params.append(param)
else:
non_vision_params.append(param)
return [
{"params": non_vision_params},
{
"params": vision_encoder_params,
"lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier,
},
]
def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
conditioning_vec = self.observation_encoder.encode(batch)
actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec)
start = n_obs_steps - 1
end = start + self.config.n_action_steps
actions = actions[:, start:end]
return actions
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
ACTION: deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations"""
self.eval()
for k in batch:
if k in self._queues:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
actions = self._generate_actions(batch)
return actions
def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Prepare batch by stacking image features if needed."""
if self.config.image_features:
batch = dict(batch) # shallow copy to avoid modifying original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
return batch
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations"""
if ACTION in batch:
batch = dict(batch) # shallow copy to avoid modifying original
batch.pop(ACTION)
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1))
action = self._queues[ACTION].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Run the batch through the model and compute the loss for training"""
batch = self._prepare_batch(batch)
conditioning_vec = self.observation_encoder.encode(batch)
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
return loss, None
# -- Observation Encoders --
class CLIPVisionEncoder(nn.Module):
"""CLIP vision encoder using the CLS token for global image representation."""
def __init__(self, model_name: str):
super().__init__()
self.model_name = model_name
self.model = CLIPVisionModel.from_pretrained(self.model_name)
self.num_non_spatial_tokens = 1
self.embed_dim = self.model.config.hidden_size
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to CLS token."""
outputs = self.model(pixel_values=x, output_hidden_states=False)
cls_token = outputs.last_hidden_state[:, 0]
b, embed_dim = cls_token.shape
return cls_token.reshape(b, embed_dim, 1, 1)
def get_output_shape(self) -> tuple:
return (self.embed_dim, 1, 1)
class CLIPTextEncoder(nn.Module):
"""CLIP text encoder with frozen weights and a learnable projection layer.
Accepts pre-tokenized inputs (input_ids and attention_mask) from the processor pipeline. See the processor
pipeline to see how the tokenization is handled.
"""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__()
self.model_name = model_name
self.projection_dim = projection_dim
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
for param in self.text_encoder.parameters():
param.requires_grad = False
self.text_embed_dim = self.text_encoder.config.hidden_size
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
"""Encode pre-tokenized text to feature vectors."""
# Ensure inputs are on the same device as the model
device = next(self.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
with torch.no_grad():
outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
clip_features = outputs.pooler_output
return self.projection(clip_features)
class ObservationEncoder(nn.Module):
"""Handles all observation processing for the conditioning vector."""
def __init__(self, config):
super().__init__()
self.config = config
self._setup_preprocessing(config)
if config.image_features:
self.num_cameras = len(config.image_features)
self.camera_names = list(config.image_features.keys())
if config.use_separate_rgb_encoder_per_camera:
self.vision_encoders = nn.ModuleList(
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
)
self.vision_encoder = None
else:
self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
self.vision_encoders = None
else:
self.vision_encoder = None
self.vision_encoders = None
self.camera_names = []
self.num_cameras = 0
if hasattr(config, "robot_state_feature") and config.robot_state_feature:
self.robot_state_dim = config.robot_state_feature.shape[0]
else:
self.robot_state_dim = 0
self.text_dim = config.hidden_dim
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
self._setup_vector_output()
def _apply_preprocessing(self, images: Tensor) -> Tensor:
if self.do_resize:
images = self.resize(images)
if self.do_crop:
images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
return images
def _setup_preprocessing(self, config):
if config.image_resize_shape is not None:
self.do_resize = True
self.resize = torchvision.transforms.Resize(
size=config.image_resize_shape,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True,
)
else:
self.do_resize = False
if config.image_crop_shape is not None:
self.do_crop = True
self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
if config.image_crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
def _setup_vector_output(self):
total_dim = 0
if self.vision_encoder is not None or self.vision_encoders is not None:
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
feature_map_shape = encoder_to_check.get_output_shape()
c, h, w = feature_map_shape
spatial_feature_dim = c * h * w
total_dim += spatial_feature_dim * self.num_cameras
total_dim += self.robot_state_dim
total_dim += self.text_dim
self.conditioning_dim = total_dim * self.config.n_obs_steps
def encode(self, batch: dict) -> Tensor:
"""Encode observations to vector format."""
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
conditioning_feats = []
conditioning_feats.append(batch[OBS_STATE])
if self.vision_encoder is not None or self.vision_encoders is not None:
images = batch[OBS_IMAGES]
if len(images.shape) == 5:
images = images.unsqueeze(1)
if self.config.use_separate_rgb_encoder_per_camera:
camera_features = []
for cam_idx in range(self.num_cameras):
cam_images = images[:, :, cam_idx]
cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w")
cam_images_flat = self._apply_preprocessing(cam_images_flat)
cam_features = self.vision_encoders[cam_idx](cam_images_flat)
cam_visual_features = cam_features.flatten(start_dim=1)
cam_features_reshaped = einops.rearrange(
cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps
)
camera_features.append(cam_features_reshaped)
img_features = torch.cat(camera_features, dim=-1)
conditioning_feats.append(img_features)
else:
images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
images_flat = self._apply_preprocessing(images_flat)
visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
img_features = einops.rearrange(
visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
)
conditioning_feats.append(img_features)
if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch:
input_ids = batch[OBS_LANGUAGE_TOKENS] # [batch_size, seq_length]
attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK] # [batch_size, seq_length]
text_features = self.text_encoder(input_ids, attention_mask)
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
conditioning_feats.append(text_features)
combined_features = torch.cat(conditioning_feats, dim=-1)
return combined_features.flatten(start_dim=1)
# -- Transformer Components --
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
"""Modulate input with shift and scale for AdaLN-Zero."""
return x * (1 + scale) + shift
class SinusoidalPosEmb(nn.Module):
"""Sinusoidal positional embeddings for timesteps."""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class RotaryPositionalEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE) for transformers."""
def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
super().__init__()
assert head_dim % 2 == 0, "head_dim must be even for RoPE"
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.base = base
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._precompute_cache(max_seq_len)
def _precompute_cache(self, seq_len: int):
t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)
def _rotate_half(self, x: Tensor) -> Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
seq_len = q.shape[2]
if seq_len > self.max_seq_len:
raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}.")
cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype)
sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype)
q_rotated = (q * cos) + (self._rotate_half(q) * sin)
k_rotated = (k * cos) + (self._rotate_half(k) * sin)
return q_rotated, k_rotated
class RoPEAttention(nn.Module):
"""Multi-head self-attention with Rotary Position Embedding (RoPE)."""
def __init__(
self,
hidden_size: int,
num_heads: int,
dropout: float = 0.0,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
super().__init__()
assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5
self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)
def forward(self, x: Tensor) -> Tensor:
B, T, _ = x.shape # noqa: N806
qkv = self.qkv_proj(x)
qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q, k = self.rope(q, k)
attn_out = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0,
)
attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size)
return self.out_proj(attn_out)
class TransformerBlock(nn.Module):
"""DiT-style transformer block with AdaLN-Zero."""
def __init__(
self,
hidden_size: int = 128,
num_heads: int = 4,
num_features: int = 128,
dropout: float = 0.0,
use_rope: bool = False,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
super().__init__()
self.use_rope = use_rope
if use_rope:
self.attn = RoPEAttention(
hidden_size=hidden_size,
num_heads=num_heads,
dropout=dropout,
max_seq_len=max_seq_len,
rope_base=rope_base,
)
else:
self.multihead_attn = nn.MultiheadAttention(
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
)
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.GELU(approximate="tanh"),
nn.Linear(hidden_size * 4, hidden_size),
)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True))
def forward(self, x: Tensor, features: Tensor) -> Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
features
).chunk(6, dim=1)
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
if self.use_rope:
attn_out = self.attn(attn_input)
else:
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
x = x + gate_msa.unsqueeze(1) * attn_out
mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1))
mlp_out = self.mlp(mlp_input)
x = x + gate_mlp.unsqueeze(1) * mlp_out
return x
class DiffusionTransformer(nn.Module):
"""Transformer-based diffusion noise prediction model."""
def __init__(self, config, conditioning_dim: int):
super().__init__()
self.config = config
self.conditioning_dim = conditioning_dim
self.action_dim = config.action_feature.shape[0]
self.horizon = config.horizon
self.hidden_size = config.hidden_dim
self.num_layers = config.num_layers
self.num_heads = config.num_heads
self.dropout = config.dropout
self.use_rope = config.use_rope
self.timestep_embed_dim = config.timestep_embed_dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(self.timestep_embed_dim),
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
nn.GELU(),
nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim),
nn.GELU(),
)
self.cond_dim = self.timestep_embed_dim + conditioning_dim
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
if config.use_positional_encoding:
self.pos_embedding = nn.Parameter(
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
)
else:
self.pos_embedding = None
self.transformer_blocks = nn.ModuleList(
[
TransformerBlock(
hidden_size=self.hidden_size,
num_heads=self.num_heads,
num_features=self.cond_dim,
dropout=self.dropout,
use_rope=self.use_rope,
max_seq_len=self.horizon,
rope_base=config.rope_base,
)
for _ in range(self.num_layers)
]
)
self.output_proj = nn.Linear(self.hidden_size, self.action_dim)
self._initialize_weights()
def _initialize_weights(self):
for block in self.transformer_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor:
_, seq_len, _ = x.shape
timestep_features = self.time_mlp(timestep)
cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1)
hidden_seq = self.input_proj(x)
if self.pos_embedding is not None:
hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :]
for block in self.transformer_blocks:
hidden_seq = block(hidden_seq, cond_features)
return self.output_proj(hidden_seq)
# -- Objectives --
class DiffusionObjective(nn.Module):
"""Standard diffusion (DDPM/DDIM) objective implementation."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__()
self.config = config
self.action_dim = action_dim
self.horizon = horizon
self.do_mask_loss_for_padding = do_mask_loss_for_padding
scheduler_kwargs = {
"num_train_timesteps": config.num_train_timesteps,
"beta_start": config.beta_start,
"beta_end": config.beta_end,
"beta_schedule": config.beta_schedule,
"clip_sample": config.clip_sample,
"clip_sample_range": config.clip_sample_range,
"prediction_type": config.prediction_type,
}
if config.noise_scheduler_type == "DDPM":
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
elif config.noise_scheduler_type == "DDIM":
self.noise_scheduler = DDIMScheduler(**scheduler_kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}")
self.num_inference_steps = (
config.num_inference_steps
if config.num_inference_steps is not None
else self.noise_scheduler.config.num_train_timesteps
)
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
clean_actions = batch[ACTION]
noise = torch.randn_like(clean_actions)
timesteps = torch.randint(
low=0,
high=self.noise_scheduler.config.num_train_timesteps,
size=(clean_actions.shape[0],),
device=clean_actions.device,
).long()
noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps)
prediction_type = self.noise_scheduler.config.prediction_type
if prediction_type == "epsilon":
target = noise
elif prediction_type == "sample":
target = clean_actions
else:
raise ValueError(f"Unsupported prediction type: {prediction_type}")
predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted, target, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_actions = ~batch["action_is_pad"]
loss = loss * valid_actions.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
sample = torch.randn(
size=(batch_size, self.horizon, self.action_dim),
dtype=dtype,
device=device,
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
for t in self.noise_scheduler.timesteps:
model_output = model(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
conditioning_vec=conditioning_vec,
)
sample = self.noise_scheduler.step(model_output, t, sample).prev_sample
return sample
class FlowMatchingObjective(nn.Module):
"""Flow matching objective: trains a model to predict velocity fields."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__()
self.config = config
self.action_dim = action_dim
self.horizon = horizon
self.do_mask_loss_for_padding = do_mask_loss_for_padding
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
if self.config.timestep_sampling_strategy == "uniform":
return torch.rand(batch_size, device=device)
elif self.config.timestep_sampling_strategy == "beta":
beta_dist = torch.distributions.Beta(
self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
)
u = beta_dist.sample((batch_size,)).to(device)
return self.config.timestep_sampling_s * (1.0 - u)
else:
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
data = batch[ACTION]
batch_size = data.shape[0]
device = data.device
noise = torch.randn_like(data)
t = self._sample_timesteps(batch_size, device)
t_expanded = t.view(-1, 1, 1)
x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise
target_velocity = data - (1 - self.config.sigma_min) * noise
predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_mask = ~batch["action_is_pad"]
loss = loss * valid_mask.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device)
num_steps = self.config.num_integration_steps
time_grid = torch.linspace(0, 1, num_steps + 1, device=device)
if self.config.integration_method == "euler":
x = self._euler_integrate(model, x, time_grid, conditioning_vec)
elif self.config.integration_method == "rk4":
x = self._rk4_integrate(model, x, time_grid, conditioning_vec)
else:
raise ValueError(f"Unknown integration method: {self.config.integration_method}")
return x
def _euler_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
x = x_init
for i in range(len(time_grid) - 1):
t_scalar = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device)
with torch.no_grad():
velocity = model(x, t_batch, conditioning_vec=conditioning_vec)
x = x + dt * velocity
return x
def _rk4_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
x = x_init
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
with torch.no_grad():
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
for i in range(len(time_grid) - 1):
t = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
k1 = dynamics(x, t)
k2 = dynamics(x + dt * k1 / 2, t + dt / 2)
k3 = dynamics(x + dt * k2 / 2, t + dt / 2)
k4 = dynamics(x + dt * k3, t + dt)
x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
return x

View File

@@ -0,0 +1,105 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_multi_task_dit_pre_post_processors(
config: MultiTaskDiTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for a Multi-Task DiT policy.
The pre-processing pipeline prepares the input data for the model by:
1. Renaming features.
2. Adding a batch dimension.
3. Tokenizing the language task description (if present).
4. Moving the data to the specified device.
5. Normalizing the input and output features based on dataset statistics.
The post-processing pipeline handles the model's output by:
1. Unnormalizing the output features to their original scale.
2. Moving the data to the CPU.
Args:
config: The configuration object for the Multi-Task DiT policy,
containing feature definitions, normalization mappings, and device information.
dataset_stats: A dictionary of statistics used for normalization.
Defaults to None.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
TokenizerProcessorStep(
tokenizer_name=config.text_encoder_name,
padding=config.tokenizer_padding,
padding_side=config.tokenizer_padding_side,
max_length=config.tokenizer_max_length,
truncation=config.tokenizer_truncation,
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
device=config.device,
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)

View File

@@ -17,6 +17,65 @@ It is designed as a **Vision-Language-Action model for general robot control**.
---
## Relative Actions
π₀ supports training with **relative actions**, where the model learns relative offsets
from the current robot state instead of absolute joint positions. This mirrors the
relative-action transform in OpenPI (`DeltaActions`) and can improve performance.
### How it works
1. **During preprocessing**, absolute actions are converted to relative offsets:
`relative = action - state` (for selected joints).
2. The relative actions are normalized using statistics computed from the relative distribution.
3. **During postprocessing**, predicted relative actions are converted back to absolute:
`absolute = relative + state`.
Joints listed in `relative_exclude_joints` (e.g., gripper) are kept absolute.
### Configuration
| Parameter | Type | Default | Description |
| ------------------------- | ----------- | ------------- | ---------------------------------------------------------------- |
| `use_relative_actions` | `bool` | `False` | Enable relative-action training |
| `relative_exclude_joints` | `list[str]` | `["gripper"]` | Joint names to keep absolute (matched by substring) |
| `action_feature_names` | `list[str]` | `None` | Auto-populated from dataset metadata at runtime by `make_policy` |
### Training example
```bash
python -m lerobot.scripts.lerobot_train \
--policy.type=pi0 \
--dataset.repo_id=your_org/your_dataset \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]'
```
When `use_relative_actions=true`, the training script automatically:
- Computes relative action statistics from the dataset (sampled chunk-level relative actions)
- Replaces the standard action stats with relative stats for normalization
- Broadcasts these stats across all ranks in distributed training
### Recomputing stats for an existing dataset
If you want to precompute relative action stats offline, use `recompute_stats` from
`lerobot.datasets.dataset_tools`:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.dataset_tools import recompute_stats
dataset = LeRobotDataset("your_org/your_dataset")
dataset = recompute_stats(
dataset,
relative_action=True,
relative_exclude_joints=["gripper"],
)
```
---
## Citation
If you use this work, please cite both **OpenPI** and the π₀ paper:

View File

@@ -50,6 +50,35 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Relative actions: converts absolute actions to relative (relative to state).
use_relative_actions: bool = False
# Joint names to exclude from relative (kept absolute). Empty list = all dims relative.
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None
# Relative state (UMI-style relative proprioception): converts multi-timestep
# observation.state to offsets from the current timestep, providing velocity info.
# Requires state_obs_steps >= 2. The flattened multi-timestep state is padded to
# max_state_dim, so ensure state_obs_steps * state_dim <= max_state_dim.
use_relative_state: bool = False
state_obs_steps: int = 1
relative_exclude_state_joints: list[str] = field(default_factory=list)
# Populated at runtime from dataset metadata by make_policy.
state_feature_names: list[str] | None = None
# Derive observation.state from the action column (UMI-style).
# When True, action_delta_indices loads one extra leading timestep [-1, 0, ..., chunk_size-1],
# DeriveStateFromActionStep extracts [action[t-1], action[t]] as a 2-step state,
# and strips the extra timestep from the action chunk.
# Implies use_relative_state=True and state_obs_steps=2.
derive_state_from_action: bool = False
# Latency compensation: skip this many steps from the start of each predicted
# action chunk during inference. E.g. at 10Hz with ~200ms total latency,
# latency_skip_steps=2 compensates for the delay.
latency_skip_steps: int = 0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
@@ -99,6 +128,10 @@ class PI0Config(PreTrainedConfig):
def __post_init__(self):
super().__post_init__()
if self.derive_state_from_action:
self.use_relative_state = True
self.state_obs_steps = 2
# Validate configuration
if self.n_action_steps > self.chunk_size:
raise ValueError(
@@ -114,6 +147,13 @@ class PI0Config(PreTrainedConfig):
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
if self.use_relative_state and self.state_obs_steps < 2:
raise ValueError(
"use_relative_state requires state_obs_steps >= 2 "
f"(got {self.state_obs_steps}). Set state_obs_steps=2 for "
"UMI-style relative proprioception."
)
def validate_features(self) -> None:
"""Validate and set up input/output features."""
for i in range(self.empty_cameras):
@@ -159,8 +199,16 @@ class PI0Config(PreTrainedConfig):
def observation_delta_indices(self) -> None:
return None
@property
def state_delta_indices(self) -> list[int] | None:
if self.state_obs_steps >= 2:
return list(range(-(self.state_obs_steps - 1), 1))
return None
@property
def action_delta_indices(self) -> list:
if self.derive_state_from_action:
return [-1] + list(range(self.chunk_size))
return list(range(self.chunk_size))
@property

View File

@@ -1230,8 +1230,11 @@ class PI0Policy(PreTrainedPolicy):
return images, img_masks
def prepare_state(self, batch):
"""Pad state"""
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
"""Flatten multi-timestep state and pad to max_state_dim."""
state = batch[OBS_STATE]
if state.ndim == 3:
state = state.flatten(start_dim=1)
state = pad_vector(state, self.config.max_state_dim)
return state
def prepare_action(self, batch):
@@ -1250,7 +1253,8 @@ class PI0Policy(PreTrainedPolicy):
# Action queue logic for n_action_steps > 1
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
skip = self.config.latency_skip_steps
actions = self.predict_action_chunk(batch)[:, skip : skip + self.config.n_action_steps]
# Transpose to get shape (n_action_steps, batch_size, action_dim)
self._action_queue.extend(actions.transpose(0, 1))

View File

@@ -21,14 +21,18 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeriveStateFromActionStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RelativeActionsProcessorStep,
RelativeStateProcessorStep,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
@@ -126,7 +130,25 @@ def make_pi0_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
derive_state_step = DeriveStateFromActionStep(
enabled=getattr(config, "derive_state_from_action", False),
)
relative_step = RelativeActionsProcessorStep(
enabled=config.use_relative_actions,
exclude_joints=getattr(config, "relative_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
relative_state_step = RelativeStateProcessorStep(
enabled=getattr(config, "use_relative_state", False),
exclude_joints=getattr(config, "relative_exclude_state_joints", []),
state_names=getattr(config, "state_feature_names", None),
)
# Order: DeriveStateFromAction extracts state from the extended action chunk,
# then relative_action uses current state[t] for subtraction,
# then relative_state converts the multi-timestep state to offsets.
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
@@ -138,6 +160,9 @@ def make_pi0_pre_post_processors(
padding="max_length",
),
DeviceProcessorStep(device=config.device),
derive_state_step,
relative_step,
relative_state_step,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
@@ -149,6 +174,7 @@ def make_pi0_pre_post_processors(
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
AbsoluteActionsProcessorStep(enabled=config.use_relative_actions, relative_step=relative_step),
DeviceProcessorStep(device="cpu"),
]

View File

@@ -17,6 +17,48 @@ It is designed as a **Vision-Language-Action model with open-world generalizatio
---
## Relative Actions
π₀.₅ supports training with **relative actions**, where the model learns relative offsets
from the current robot state instead of absolute joint positions. This mirrors the
relative-action transform in OpenPI (`DeltaActions`) and can improve performance.
### How it works
1. **During preprocessing**, absolute actions are converted to relative offsets:
`relative = action - state` (for selected joints).
2. The relative actions are normalized using statistics computed from the relative distribution.
3. **During postprocessing**, predicted relative actions are converted back to absolute:
`absolute = relative + state`.
Joints listed in `relative_exclude_joints` (e.g., gripper) are kept absolute.
### Configuration
| Parameter | Type | Default | Description |
| ------------------------- | ----------- | ------------- | ---------------------------------------------------------------- |
| `use_relative_actions` | `bool` | `False` | Enable relative-action training |
| `relative_exclude_joints` | `list[str]` | `["gripper"]` | Joint names to keep absolute (matched by substring) |
| `action_feature_names` | `list[str]` | `None` | Auto-populated from dataset metadata at runtime by `make_policy` |
### Training example
```bash
python -m lerobot.scripts.lerobot_train \
--policy.type=pi05 \
--dataset.repo_id=your_org/your_dataset \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]'
```
When `use_relative_actions=true`, the training script automatically:
- Computes relative action statistics from the dataset (sampled chunk-level relative actions)
- Replaces the standard action stats with relative stats for normalization
- Broadcasts these stats across all ranks in distributed training
---
## Citation
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:

View File

@@ -50,6 +50,13 @@ class PI05Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Relative actions: converts absolute actions to relative (relative to state).
use_relative_actions: bool = False
# Joint names to exclude from relative (kept absolute). Empty list = all dims relative.
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None

View File

@@ -24,6 +24,7 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
@@ -31,6 +32,7 @@ from lerobot.processor import (
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RelativeActionsProcessorStep,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
@@ -125,10 +127,17 @@ def make_pi05_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
relative_step = RelativeActionsProcessorStep(
enabled=config.use_relative_actions,
exclude_joints=getattr(config, "relative_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
# OpenPI order: raw → relative → normalize → model → unnormalize → absolute
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
relative_step,
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep(
@@ -150,6 +159,7 @@ def make_pi05_pre_post_processors(
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
AbsoluteActionsProcessorStep(enabled=config.use_relative_actions, relative_step=relative_step),
DeviceProcessorStep(device="cpu"),
]

View File

@@ -41,6 +41,13 @@ class PI0FastConfig(PreTrainedConfig):
max_action_dim: int = 32
max_action_tokens: int = 256
# Relative actions: converts absolute actions to relative (relative to state).
use_relative_actions: bool = False
# Joint names to exclude from relative (kept absolute). Empty list = all dims relative.
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None

View File

@@ -24,6 +24,7 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.processor import (
AbsoluteActionsProcessorStep,
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -32,6 +33,7 @@ from lerobot.processor import (
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RelativeActionsProcessorStep,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
@@ -125,12 +127,24 @@ def make_pi0_fast_pre_post_processors(
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
relative_step = RelativeActionsProcessorStep(
enabled=config.use_relative_actions,
exclude_joints=getattr(config, "relative_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
# Pi0Fast order: relative → normalize → tokenize → model → unnormalize → absolute
# This matches pi0/pi0.5: RelativeActionsProcessorStep runs first on raw absolute actions,
# caching the raw state. NormalizerProcessorStep then normalizes the raw relative actions,
# so the normalizer (and action tokenizer) sees delta values, relative stats are required.
# NOTE: RelativeActionsProcessorStep only modifies the action in the transition; it reads
# state from the observation but does not change it. NormalizerProcessorStep still runs
# before Pi0FastPrepareStateAndLanguageTokenizerProcessorStep, so the state tokenizer
# continues to receive normalized state in [-1, 1] as expected.
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
# NOTE: NormalizerProcessorStep MUST come before Pi0FastPrepareStateAndLanguageTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
relative_step,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
@@ -156,6 +170,7 @@ def make_pi0_fast_pre_post_processors(
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
AbsoluteActionsProcessorStep(enabled=config.use_relative_actions, relative_step=relative_step),
DeviceProcessorStep(device="cpu"),
]

View File

@@ -55,7 +55,7 @@ class SmolVLAConfig(PreTrainedConfig):
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Converts joint dimensions to relative values with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False

View File

@@ -75,6 +75,15 @@ from .policy_robot_bridge import (
PolicyActionToRobotActionProcessorStep,
RobotActionToPolicyActionProcessorStep,
)
from .relative_action_processor import (
AbsoluteActionsProcessorStep,
DeriveStateFromActionStep,
RelativeActionsProcessorStep,
RelativeStateProcessorStep,
to_absolute_actions,
to_relative_actions,
to_relative_state,
)
from .rename_processor import RenameObservationsProcessorStep
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
@@ -100,6 +109,10 @@ __all__ = [
"make_default_teleop_action_processor",
"make_default_robot_action_processor",
"make_default_robot_observation_processor",
"AbsoluteActionsProcessorStep",
"DeriveStateFromActionStep",
"RelativeActionsProcessorStep",
"RelativeStateProcessorStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"NormalizerProcessorStep",
@@ -129,6 +142,9 @@ __all__ = [
"transition_to_batch",
"TransitionKey",
"TruncatedProcessorStep",
"to_absolute_actions",
"to_relative_actions",
"to_relative_state",
"UnnormalizerProcessorStep",
"VanillaObservationProcessorStep",
]

View File

@@ -0,0 +1,367 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
import torch
from torch import Tensor
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import OBS_STATE
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
from .pipeline import ProcessorStep, ProcessorStepRegistry
# Re-export for backward compatibility
__all__ = [
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"DeriveStateFromActionStep",
"RelativeActionsProcessorStep",
"AbsoluteActionsProcessorStep",
"RelativeStateProcessorStep",
"to_relative_actions",
"to_absolute_actions",
"to_relative_state",
]
def to_relative_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
"""Convert absolute actions to relative: relative = action - state (for masked dims).
Args:
actions: (B, T, action_dim) or (B, action_dim).
state: (B, state_dim). Broadcast across time dimension.
mask: Which dims to convert. Can be shorter than action_dim.
"""
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
dims = mask_t.shape[0]
# Align state to the same device/dtype as actions. _last_state is cached before
# DeviceProcessorStep moves the transition, so it can be on CPU while actions are on CUDA.
if state.device != actions.device or state.dtype != actions.dtype:
state = state.to(device=actions.device, dtype=actions.dtype)
state_offset = state[..., :dims] * mask_t
if actions.ndim == 3:
state_offset = state_offset.unsqueeze(-2)
actions = actions.clone()
actions[..., :dims] -= state_offset
return actions
def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
"""Convert relative actions back to absolute: absolute = relative + state (for masked dims).
Args:
actions: (B, T, action_dim) or (B, action_dim).
state: (B, state_dim). Broadcast across time dimension.
mask: Which dims to convert. Can be shorter than action_dim.
"""
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
dims = mask_t.shape[0]
# Align state to the same device/dtype as actions. _last_state is cached before
# DeviceProcessorStep moves the transition, so it can be on CPU while actions are on CUDA.
if state.device != actions.device or state.dtype != actions.dtype:
state = state.to(device=actions.device, dtype=actions.dtype)
state_offset = state[..., :dims] * mask_t
if actions.ndim == 3:
state_offset = state_offset.unsqueeze(-2)
actions = actions.clone()
actions[..., :dims] += state_offset
return actions
@ProcessorStepRegistry.register("derive_state_from_action_processor")
@dataclass
class DeriveStateFromActionStep(ProcessorStep):
"""Derives 2-step observation.state from the action chunk (UMI-style).
Expects action with one extra leading timestep: [B, chunk_size+1, D]
from action_delta_indices = [-1, 0, 1, ..., chunk_size-1].
Extracts [action[t-1], action[t]] as state and strips the extra timestep.
No-op during inference (state comes from robot).
"""
enabled: bool = False
def __call__(self, transition: EnvTransition) -> EnvTransition:
if not self.enabled:
return transition
action = transition.get(TransitionKey.ACTION)
if action is None or action.ndim < 3:
return transition
new_transition = transition.copy()
new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {}))
new_obs[OBS_STATE] = action[..., :2, :]
new_transition[TransitionKey.ACTION] = action[..., 1:, :]
new_transition[TransitionKey.OBSERVATION] = new_obs
return new_transition
def get_config(self) -> dict[str, Any]:
return {"enabled": self.enabled}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class RelativeActionsProcessorStep(ProcessorStep):
"""Converts absolute actions to relative actions (action -= state) for masked dimensions.
Mirrors OpenPI's DeltaActions transform. Applied during preprocessing so the model
trains on relative offsets instead of absolute positions.
Caches the last seen state so a paired AbsoluteActionsProcessorStep can reverse
the conversion during postprocessing.
Attributes:
enabled: Whether to apply the relative conversion.
exclude_joints: Joint names to keep absolute (not converted to relative).
action_names: Action dimension names from dataset metadata, used to build
the mask from exclude_joints. If None, all dims are converted.
"""
enabled: bool = False
exclude_joints: list[str] = field(default_factory=list)
action_names: list[str] | None = None
_last_state: torch.Tensor | None = field(default=None, init=False, repr=False)
def _build_mask(self, action_dim: int) -> list[bool]:
if not self.exclude_joints or self.action_names is None:
return [True] * action_dim
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
if not exclude_tokens:
return [True] * action_dim
mask = []
for name in self.action_names[:action_dim]:
action_name = str(name).lower()
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
mask.append(not is_excluded)
if len(mask) < action_dim:
mask.extend([True] * (action_dim - len(mask)))
return mask
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION, {})
raw_state = observation.get(OBS_STATE) if observation else None
# When state_delta_indices loads multi-timestep state [B, n_obs, D],
# use only the current (last) timestep for relative action conversion.
if raw_state is not None:
state = raw_state[..., -1, :] if raw_state.ndim >= 3 else raw_state
else:
state = None
# Always cache state for the paired AbsoluteActionsProcessorStep
if state is not None:
self._last_state = state
if not self.enabled:
return transition
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is None or state is None:
return new_transition
mask = self._build_mask(action.shape[-1])
new_transition[TransitionKey.ACTION] = to_relative_actions(action, state, mask)
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"enabled": self.enabled,
"exclude_joints": self.exclude_joints,
"action_names": self.action_names,
}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def to_relative_state(state: Tensor, mask: Sequence[bool]) -> Tensor:
"""Convert multi-timestep absolute state to relative (offset from current timestep).
Each timestep becomes: ``state[..., t, :] - state[..., -1, :]`` for masked dims.
The last (current) timestep becomes zeros for masked dims.
Args:
state: (..., n_obs, state_dim) — last timestep is the reference (current).
mask: Which dims to convert. Can be shorter than state_dim.
"""
mask_t = torch.tensor(mask, dtype=state.dtype, device=state.device)
dims = mask_t.shape[0]
current = state[..., -1:, :] # (..., 1, state_dim)
state = state.clone()
state[..., :dims] -= current[..., :dims] * mask_t
return state
@ProcessorStepRegistry.register("relative_state_processor")
@dataclass
class RelativeStateProcessorStep(ProcessorStep):
"""Converts observation.state to relative (offset from current timestep).
UMI-style relative proprioception: each state timestep is expressed as
an offset from the current EE pose, providing velocity information.
During training (multi-timestep input from ``state_delta_indices``):
``state[..., t, :] -= state[..., -1, :]`` — subtract current from all.
During inference (single timestep): buffers the previous state and stacks
``[previous, current]`` before applying the relative conversion, producing
the same ``[n_obs, D]`` shape the model expects.
Attributes:
enabled: Whether to apply the relative conversion.
exclude_joints: Joint/dim names to keep absolute.
state_names: State dimension names from dataset metadata.
"""
enabled: bool = False
exclude_joints: list[str] = field(default_factory=list)
state_names: list[str] | None = None
_previous_state: torch.Tensor | None = field(default=None, init=False, repr=False)
def _build_mask(self, state_dim: int) -> list[bool]:
if not self.exclude_joints or self.state_names is None:
return [True] * state_dim
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
if not exclude_tokens:
return [True] * state_dim
mask = []
for name in self.state_names[:state_dim]:
state_name = str(name).lower()
is_excluded = any(token == state_name or token in state_name for token in exclude_tokens)
mask.append(not is_excluded)
if len(mask) < state_dim:
mask.extend([True] * (state_dim - len(mask)))
return mask
def __call__(self, transition: EnvTransition) -> EnvTransition:
if not self.enabled:
return transition
observation = transition.get(TransitionKey.OBSERVATION, {})
state = observation.get(OBS_STATE) if observation else None
if state is None:
return transition
new_transition = transition.copy()
new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {}))
mask = self._build_mask(state.shape[-1])
if state.ndim >= 3:
# [B, n_obs, D] — multi-timestep (training with state_delta_indices)
relative = to_relative_state(state, mask)
new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, n_obs*D]
elif state.ndim == 2:
# [B, D] — single timestep (inference): buffer previous and stack
current = state
if self._previous_state is None:
self._previous_state = current.clone()
prev = self._previous_state
if prev.device != current.device or prev.dtype != current.dtype:
prev = prev.to(device=current.device, dtype=current.dtype)
stacked = torch.stack([prev, current], dim=-2) # [B, 2, D]
relative = to_relative_state(stacked, mask)
new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, 2*D]
self._previous_state = current.clone()
new_transition[TransitionKey.OBSERVATION] = new_obs
return new_transition
def reset(self) -> None:
"""Reset the state buffer. Call at episode boundaries during inference."""
self._previous_state = None
def get_config(self) -> dict[str, Any]:
return {
"enabled": self.enabled,
"exclude_joints": self.exclude_joints,
"state_names": self.state_names,
}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("absolute_actions_processor")
@dataclass
class AbsoluteActionsProcessorStep(ProcessorStep):
"""Converts relative actions back to absolute actions (action += state) for all dimensions.
Mirrors OpenPI's AbsoluteActions transform. Applied during postprocessing so
predicted relative offsets are converted back to absolute positions for execution.
Reads the cached state from its paired RelativeActionsProcessorStep.
Attributes:
enabled: Whether to apply the absolute conversion.
relative_step: Reference to the paired RelativeActionsProcessorStep that caches state.
"""
enabled: bool = False
relative_step: RelativeActionsProcessorStep | None = field(default=None, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition:
if not self.enabled:
return transition
if self.relative_step is None:
raise RuntimeError(
"AbsoluteActionsProcessorStep requires a paired RelativeActionsProcessorStep "
"but relative_step is None. Ensure relative_step is set when constructing the postprocessor."
)
if self.relative_step._last_state is None:
raise RuntimeError(
"AbsoluteActionsProcessorStep requires state from RelativeActionsProcessorStep "
"but no state has been cached. Ensure the preprocessor runs before the postprocessor."
)
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is None:
return new_transition
mask = self.relative_step._build_mask(action.shape[-1])
new_transition[TransitionKey.ACTION] = to_absolute_actions(
action, self.relative_step._last_state, mask
)
return new_transition
def get_config(self) -> dict[str, Any]:
return {"enabled": self.enabled}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features

View File

@@ -563,7 +563,7 @@ class ReplayBuffer:
)
# Start writing images if needed
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
lerobot_dataset.writer.start_image_writer(num_processes=0, num_threads=3)
# Convert transitions into episodes and frames
@@ -603,10 +603,10 @@ class ReplayBuffer:
lerobot_dataset.save_episode()
# Save any remaining frames in the buffer
if lerobot_dataset.episode_buffer["size"] > 0:
if lerobot_dataset.has_pending_frames():
lerobot_dataset.save_episode()
lerobot_dataset.stop_image_writer()
lerobot_dataset.writer.stop_image_writer()
lerobot_dataset.finalize()
return lerobot_dataset

View File

@@ -752,8 +752,7 @@ def replay_trajectory(
episodes=[cfg.dataset.replay_episode],
download_videos=False,
)
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode)
actions = episode_frames.select_columns(ACTION)
actions = dataset.select_columns(ACTION)
_, info = env.reset()

View File

@@ -39,19 +39,31 @@ class BiOpenArmFollower(Robot):
super().__init__(config)
self.config = config
# Top-level cameras are distributed evenly: each arm's OpenArmFollower
# will only open the cameras assigned to it. Per-arm cameras are used
# as fallback when top-level cameras are empty.
if config.cameras:
left_cameras = config.cameras
right_cameras = {}
else:
left_cameras = config.left_arm_config.cameras
right_cameras = config.right_arm_config.cameras
left_arm_config = OpenArmFollowerConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=config.left_arm_config.cameras,
cameras=left_cameras,
side=config.left_arm_config.side,
can_interface=config.left_arm_config.can_interface,
use_can_fd=config.left_arm_config.use_can_fd,
can_bitrate=config.left_arm_config.can_bitrate,
can_data_bitrate=config.left_arm_config.can_data_bitrate,
motor_config=config.left_arm_config.motor_config,
gripper_port=config.left_arm_config.gripper_port,
gripper_motor_ids=config.left_arm_config.gripper_motor_ids,
position_kd=config.left_arm_config.position_kd,
position_kp=config.left_arm_config.position_kp,
joint_limits=config.left_arm_config.joint_limits,
@@ -63,13 +75,15 @@ class BiOpenArmFollower(Robot):
port=config.right_arm_config.port,
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
max_relative_target=config.right_arm_config.max_relative_target,
cameras=config.right_arm_config.cameras,
cameras=right_cameras,
side=config.right_arm_config.side,
can_interface=config.right_arm_config.can_interface,
use_can_fd=config.right_arm_config.use_can_fd,
can_bitrate=config.right_arm_config.can_bitrate,
can_data_bitrate=config.right_arm_config.can_data_bitrate,
motor_config=config.right_arm_config.motor_config,
gripper_port=config.right_arm_config.gripper_port,
gripper_motor_ids=config.right_arm_config.gripper_motor_ids,
position_kd=config.right_arm_config.position_kd,
position_kp=config.right_arm_config.position_kp,
joint_limits=config.right_arm_config.joint_limits,
@@ -93,13 +107,10 @@ class BiOpenArmFollower(Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
left_arm_cameras_ft = self.left_arm._cameras_ft
right_arm_cameras_ft = self.right_arm._cameras_ft
return {
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
}
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
# "right_wrist"), so we merge them directly — unlike motors which need the
# left_/right_ prefix to disambiguate identical per-arm joint names.
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -139,13 +150,17 @@ class BiOpenArmFollower(Robot):
def get_observation(self) -> RobotObservation:
obs_dict = {}
# Add "left_" prefix
left_obs = self.left_arm.get_observation()
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
# Camera keys that should NOT get the arm prefix (they already have unique names)
left_cam_keys = set(self.left_arm.cameras.keys())
right_cam_keys = set(self.right_arm.cameras.keys())
left_obs = self.left_arm.get_observation()
for key, value in left_obs.items():
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
# Add "right_" prefix
right_obs = self.right_arm.get_observation()
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
for key, value in right_obs.items():
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
return obs_dict

View File

@@ -14,8 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from lerobot.robots.openarm_follower import OpenArmFollowerConfigBase
from ..config import RobotConfig
@@ -28,3 +29,6 @@ class BiOpenArmFollowerConfig(RobotConfig):
left_arm_config: OpenArmFollowerConfigBase
right_arm_config: OpenArmFollowerConfigBase
# Top-level cameras shared across both arms.
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -28,7 +28,8 @@ LEFT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
"joint_5": (-85.0, 85.0),
"joint_6": (-40.0, 40.0),
"joint_7": (-80.0, 80.0),
"gripper": (-65.0, 0.0),
"proximal": (0.0, 100.0),
"distal": (0.0, 100.0),
}
RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
@@ -39,7 +40,8 @@ RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
"joint_5": (-85.0, 85.0),
"joint_6": (-40.0, 40.0),
"joint_7": (-80.0, 80.0),
"gripper": (-65.0, 0.0),
"proximal": (0.0, 100.0),
"distal": (0.0, 100.0),
}
@@ -73,13 +75,8 @@ class OpenArmFollowerConfigBase:
# Camera configurations
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# Motor configuration for OpenArms (7 DOF per arm)
# Arm motor configuration (7 DOF, Damiao on CAN bus)
# Maps motor names to (send_can_id, recv_can_id, motor_type)
# Based on: https://docs.openarm.dev/software/setup/configure-test
# OpenArms uses 4 types of motors:
# - DM8009 (DM-J8009P-2EC) for shoulders (high torque)
# - DM4340P and DM4340 for shoulder rotation and elbow
# - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper
motor_config: dict[str, tuple[int, int, str]] = field(
default_factory=lambda: {
"joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009)
@@ -89,19 +86,18 @@ class OpenArmFollowerConfigBase:
"joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310)
"joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310)
"joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310)
"gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310)
}
)
# MIT control parameters for position control (used in send_action)
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
position_kp: list[float] = field(
default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 25.0]
)
position_kd: list[float] = field(default_factory=lambda: [5.0, 5.0, 3.0, 5.0, 0.3, 0.3, 0.3, 0.3])
# UMI-style gripper (Feetech STS3215 on serial bus)
gripper_port: str = "/dev/ttyUSB0"
gripper_motor_ids: dict[str, int] = field(default_factory=lambda: {"proximal": 1, "distal": 2})
# Values for joint limits. Can be overridden via CLI (for custom values) or by setting config.side to either 'left' or 'right'.
# If config.side is left set to None and no CLI values are passed, the default joint limit values are small for safety.
# MIT control parameters for the 7 arm joints
position_kp: list[float] = field(default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0])
position_kd: list[float] = field(default_factory=lambda: [5.0, 5.0, 3.0, 5.0, 0.3, 0.3, 0.3])
# Joint limits. Can be overridden via CLI or by setting config.side to 'left' or 'right'.
joint_limits: dict[str, tuple[float, float]] = field(
default_factory=lambda: {
"joint_1": (-5.0, 5.0),
@@ -111,7 +107,8 @@ class OpenArmFollowerConfigBase:
"joint_5": (-5.0, 5.0),
"joint_6": (-5.0, 5.0),
"joint_7": (-5.0, 5.0),
"gripper": (-5.0, 0.0),
"proximal": (0.0, 100.0),
"distal": (0.0, 100.0),
}
)

View File

@@ -22,6 +22,7 @@ from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.damiao import DamiaoMotorsBus
from lerobot.motors.feetech import FeetechMotorsBus, OperatingMode
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
@@ -38,8 +39,7 @@ logger = logging.getLogger(__name__)
class OpenArmFollower(Robot):
"""
OpenArms Follower Robot which uses CAN bus communication to control 7 DOF arm with a gripper.
The arm uses Damiao motors in MIT control mode.
OpenArms Follower Robot: 7 DOF Damiao arm (CAN) + UMI-style Feetech gripper (serial).
"""
config_class = OpenArmFollowerConfig
@@ -49,19 +49,17 @@ class OpenArmFollower(Robot):
super().__init__(config)
self.config = config
# Arm motors
motors: dict[str, Motor] = {}
# Arm motors (Damiao on CAN bus)
arm_motors: dict[str, Motor] = {}
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
motor = Motor(
send_id, motor_type_str, MotorNormMode.DEGREES
) # Always use degrees for Damiao motors
motor = Motor(send_id, motor_type_str, MotorNormMode.DEGREES)
motor.recv_id = recv_id
motor.motor_type_str = motor_type_str
motors[motor_name] = motor
arm_motors[motor_name] = motor
self.bus = DamiaoMotorsBus(
port=self.config.port,
motors=motors,
motors=arm_motors,
calibration=self.calibration,
can_interface=self.config.can_interface,
use_can_fd=self.config.use_can_fd,
@@ -69,6 +67,17 @@ class OpenArmFollower(Robot):
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
)
# Gripper motors (Feetech STS3215 on serial bus)
gripper_motors: dict[str, Motor] = {
name: Motor(motor_id, "sts3215", MotorNormMode.RANGE_0_100)
for name, motor_id in config.gripper_motor_ids.items()
}
self.gripper_bus = FeetechMotorsBus(
port=config.gripper_port,
motors=gripper_motors,
calibration=self.calibration,
)
if config.side is not None:
if config.side == "left":
config.joint_limits = LEFT_DEFAULT_JOINTS_LIMITS
@@ -84,7 +93,6 @@ class OpenArmFollower(Robot):
)
logger.info(f"Values used for joint limits: {config.joint_limits}.")
# Initialize cameras
self.cameras = make_cameras_from_configs(config.cameras)
@property
@@ -93,8 +101,10 @@ class OpenArmFollower(Robot):
features: dict[str, type] = {}
for motor in self.bus.motors:
features[f"{motor}.pos"] = float
features[f"{motor}.vel"] = float # Add this
features[f"{motor}.torque"] = float # Add this
features[f"{motor}.vel"] = float
features[f"{motor}.torque"] = float
for motor in self.gripper_bus.motors:
features[f"{motor}.pos"] = float
return features
@property
@@ -116,8 +126,11 @@ class OpenArmFollower(Robot):
@property
def is_connected(self) -> bool:
"""Check if robot is connected."""
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
return (
self.bus.is_connected
and self.gripper_bus.is_connected
and all(cam.is_connected for cam in self.cameras.values())
)
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
@@ -127,12 +140,12 @@ class OpenArmFollower(Robot):
We assume that at connection time, the arms are in a safe rest position,
and torque can be safely disabled to run calibration if needed.
"""
# Connect to CAN bus
logger.info(f"Connecting arm on {self.config.port}...")
self.bus.connect()
# Run calibration if needed
logger.info(f"Connecting gripper on {self.config.gripper_port}...")
self.gripper_bus.connect()
if not self.is_calibrated and calibrate:
logger.info(
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
@@ -144,7 +157,7 @@ class OpenArmFollower(Robot):
self.configure()
if self.is_calibrated:
if self.bus.is_calibrated:
self.bus.set_zero_position()
self.bus.enable_torque()
@@ -153,47 +166,39 @@ class OpenArmFollower(Robot):
@property
def is_calibrated(self) -> bool:
"""Check if robot is calibrated."""
return self.bus.is_calibrated
return self.bus.is_calibrated and self.gripper_bus.is_calibrated
def calibrate(self) -> None:
"""
Run calibration procedure for OpenArms robot.
Run calibration for both the Damiao arm and Feetech gripper.
The calibration procedure:
1. Disable torque
2. Ask user to position arms in hanging position with grippers closed
3. Set this as zero position
4. Record range of motion for each joint
5. Save calibration
Arm calibration: set zero position with arm hanging, ±90° default range.
Gripper calibration: SO100-style half-turn homing + range recording.
"""
if self.calibration:
# Calibration file exists, ask user whether to use it or run new calibration
user_input = input(
f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
)
if user_input.strip().lower() != "c":
logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
self.bus.write_calibration(self.calibration)
self.gripper_bus.write_calibration(self.calibration)
return
logger.info(f"\nRunning calibration for {self}")
self.bus.disable_torque()
# Step 1: Set zero position
# --- Arm calibration (Damiao) ---
self.bus.disable_torque()
input(
"\nCalibration: Set Zero Position)\n"
"\nCalibration: Set Zero Position\n"
"Position the arm in the following configuration:\n"
" - Arm hanging straight down\n"
" - Gripper closed\n"
"Press ENTER when ready..."
)
# Set current position as zero for all motors
self.bus.set_zero_position()
logger.info("Arm zero position set.")
logger.info("Setting range: -90° to +90° for safety by default for all joints")
for motor_name, motor in self.bus.motors.items():
self.calibration[motor_name] = MotorCalibration(
id=motor.id,
@@ -202,17 +207,52 @@ class OpenArmFollower(Robot):
range_min=-90,
range_max=90,
)
self.bus.write_calibration(self.calibration)
# --- Gripper calibration (Feetech) ---
self.gripper_bus.disable_torque()
for motor in self.gripper_bus.motors:
self.gripper_bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
input("Move gripper to the middle of its range of motion and press ENTER....")
homing_offsets = self.gripper_bus.set_half_turn_homings()
gripper_motor_names = list(self.gripper_bus.motors.keys())
print(
f"Move gripper joints ({', '.join(gripper_motor_names)}) through their "
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
)
range_mins, range_maxes = self.gripper_bus.record_ranges_of_motion(gripper_motor_names)
for motor_name, m in self.gripper_bus.motors.items():
self.calibration[motor_name] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=homing_offsets[motor_name],
range_min=range_mins[motor_name],
range_max=range_maxes[motor_name],
)
self.gripper_bus.write_calibration(self.calibration)
self._save_calibration()
print(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
"""Configure motors with appropriate settings."""
# TODO(Steven, Pepijn): Slightly different from what it is happening in the leader
"""Configure both arm (Damiao) and gripper (Feetech) motors."""
with self.bus.torque_disabled():
self.bus.configure_motors()
with self.gripper_bus.torque_disabled():
self.gripper_bus.configure_motors()
for motor in self.gripper_bus.motors:
self.gripper_bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.gripper_bus.write("P_Coefficient", motor, 16)
self.gripper_bus.write("I_Coefficient", motor, 0)
self.gripper_bus.write("D_Coefficient", motor, 32)
self.gripper_bus.write("Max_Torque_Limit", motor, 500)
self.gripper_bus.write("Protection_Current", motor, 250)
self.gripper_bus.write("Overload_Torque", motor, 25)
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -220,25 +260,23 @@ class OpenArmFollower(Robot):
@check_if_not_connected
def get_observation(self) -> RobotObservation:
"""
Get current observation from robot including position, velocity, and torque.
Reads all motor states (pos/vel/torque) in one CAN refresh cycle
instead of 3 separate reads.
"""
"""Read all motor states from arm (CAN) and gripper (serial), plus cameras."""
start = time.perf_counter()
obs_dict: dict[str, Any] = {}
# Arm motors (Damiao) — pos/vel/torque in one CAN refresh cycle
states = self.bus.sync_read_all_states()
for motor in self.bus.motors:
state = states.get(motor, {})
obs_dict[f"{motor}.pos"] = state.get("position", 0.0)
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
# Capture images from cameras
# Gripper motors (Feetech) — position only
gripper_positions = self.gripper_bus.sync_read("Present_Position")
for motor, val in gripper_positions.items():
obs_dict[f"{motor}.pos"] = val
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.read_latest()
@@ -258,86 +296,76 @@ class OpenArmFollower(Robot):
custom_kd: dict[str, float] | None = None,
) -> RobotAction:
"""
Send action command to robot.
The action magnitude may be clipped based on safety limits.
Send action command to robot. Arm joints go to Damiao CAN bus,
gripper joints go to Feetech serial bus.
Args:
action: Dictionary with motor positions (e.g., "joint_1.pos", "joint_2.pos")
custom_kp: Optional custom kp gains per motor (e.g., {"joint_1": 120.0, "joint_2": 150.0})
custom_kd: Optional custom kd gains per motor (e.g., {"joint_1": 1.5, "joint_2": 2.0})
action: Dictionary with motor positions (e.g., "joint_1.pos", "proximal.pos")
custom_kp: Optional custom kp gains per arm motor
custom_kd: Optional custom kd gains per arm motor
Returns:
The action actually sent (potentially clipped)
"""
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
# Apply joint limit clipping to arm
# Apply joint limit clipping
for motor_name, position in goal_pos.items():
if motor_name in self.config.joint_limits:
min_limit, max_limit = self.config.joint_limits[motor_name]
clipped_position = max(min_limit, min(max_limit, position))
if clipped_position != position:
logger.debug(f"Clipped {motor_name} from {position:.2f}° to {clipped_position:.2f}°")
logger.debug(f"Clipped {motor_name} from {position:.2f} to {clipped_position:.2f}")
goal_pos[motor_name] = clipped_position
# Cap goal position when too far away from present position.
# /!\ Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
# Split into arm and gripper actions
arm_motors = set(self.bus.motors.keys())
gripper_motors = set(self.gripper_bus.motors.keys())
arm_goal = {k: v for k, v in goal_pos.items() if k in arm_motors}
gripper_goal = {k: v for k, v in goal_pos.items() if k in gripper_motors}
# Cap arm goal position when too far away from present position
if self.config.max_relative_target is not None and arm_goal:
present_pos = self.bus.sync_read("Present_Position")
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in arm_goal.items()}
arm_goal = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
# TODO(Steven, Pepijn): Refactor writing
# Motor name to index mapping for gains
motor_index = {
"joint_1": 0,
"joint_2": 1,
"joint_3": 2,
"joint_4": 3,
"joint_5": 4,
"joint_6": 5,
"joint_7": 6,
"gripper": 7,
}
# Arm: batch MIT control (Damiao)
if arm_goal:
arm_motor_names = list(self.bus.motors.keys())
commands = {}
for motor_name, position_degrees in arm_goal.items():
idx = arm_motor_names.index(motor_name) if motor_name in arm_motor_names else 0
if custom_kp is not None and motor_name in custom_kp:
kp = custom_kp[motor_name]
else:
kp = (
self.config.position_kp[idx]
if isinstance(self.config.position_kp, list)
else self.config.position_kp
)
if custom_kd is not None and motor_name in custom_kd:
kd = custom_kd[motor_name]
else:
kd = (
self.config.position_kd[idx]
if isinstance(self.config.position_kd, list)
else self.config.position_kd
)
commands[motor_name] = (kp, kd, position_degrees, 0.0, 0.0)
self.bus._mit_control_batch(commands)
# Use batch MIT control for arm (sends all commands, then collects responses)
commands = {}
for motor_name, position_degrees in goal_pos.items():
idx = motor_index.get(motor_name, 0)
# Use custom gains if provided, otherwise use config defaults
if custom_kp is not None and motor_name in custom_kp:
kp = custom_kp[motor_name]
else:
kp = (
self.config.position_kp[idx]
if isinstance(self.config.position_kp, list)
else self.config.position_kp
)
if custom_kd is not None and motor_name in custom_kd:
kd = custom_kd[motor_name]
else:
kd = (
self.config.position_kd[idx]
if isinstance(self.config.position_kd, list)
else self.config.position_kd
)
commands[motor_name] = (kp, kd, position_degrees, 0.0, 0.0)
self.bus._mit_control_batch(commands)
# Gripper: position control (Feetech)
if gripper_goal:
self.gripper_bus.sync_write("Goal_Position", gripper_goal)
goal_pos.update(arm_goal)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
@check_if_not_connected
def disconnect(self):
"""Disconnect from robot."""
# Disconnect CAN bus
self.bus.disconnect(self.config.disable_torque_on_disconnect)
# Disconnect cameras
self.gripper_bus.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -0,0 +1,630 @@
<?xml version='1.0' encoding='utf-8'?>
<robot name="openarm">
<link name="world" />
<joint name="openarm_body_world_joint" type="fixed">
<parent link="world" />
<child link="openarm_body_link0" />
<origin rpy="0 0 0" xyz="0 0 0" />
</joint>
<link name="openarm_body_link0">
<visual name="openarm_body_link0_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<geometry>
<mesh filename="./meshes/body/v10/visual/body_link0.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_body_link0_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<geometry>
<mesh filename="./meshes/body/v10/collision/body_link0_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<mass value="13.89" />
<inertia ixx="1.653" ixy="0.0" ixz="0.0" iyy="1.653" iyz="0.0" izz="0.051" />
</inertial>
</link>
<joint name="openarm_left_openarm_body_link0_joint" type="fixed">
<parent link="openarm_body_link0" />
<child link="openarm_left_link0" />
<origin rpy="-1.5708 0 0" xyz="0.0 0.031 0.698" />
</joint>
<link name="openarm_left_link0">
<visual name="openarm_left_link0_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link0.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link0_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link0_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.0009483362816297526 -0.0001580207020448382 0.03076860287587199" />
<mass value="1.1432284943239561" />
<inertia ixx="0.001128" ixy="-4e-06" ixz="-3.3e-05" iyy="0.000962" iyz="-7e-06" izz="0.00147" />
</inertial>
</link>
<link name="openarm_left_link1">
<visual name="openarm_left_link1_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link1.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link1_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link1_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="0.0011467657911800769 -3.319987657026362e-05 0.05395284380736254" />
<mass value="1.1416684646202298" />
<inertia ixx="0.001567" ixy="-1e-06" ixz="-2.9e-05" iyy="0.001273" iyz="1e-06" izz="0.001016" />
</inertial>
</link>
<joint name="openarm_left_joint1" type="revolute">
<origin rpy="0 0 0" xyz="0.0 0.0 0.0625" />
<parent link="openarm_left_link0" />
<child link="openarm_left_link1" />
<axis xyz="0 0 1" />
<limit effort="40" lower="-3.490659" upper="1.3962629999999998" velocity="16.754666" />
</joint>
<link name="openarm_left_link2">
<visual name="openarm_left_link2_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link2.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link2_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link2_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="0.00839629182351943 2.0145102027597523e-08 0.03256649300522363" />
<mass value="0.2775092746011571" />
<inertia ixx="0.000359" ixy="1e-06" ixz="-0.000109" iyy="0.000376" iyz="1e-06" izz="0.000232" />
</inertial>
</link>
<joint name="openarm_left_joint2" type="revolute">
<origin rpy="-1.57079632679 0 0" xyz="-0.0301 0.0 0.06" />
<parent link="openarm_left_link1" />
<child link="openarm_left_link2" />
<axis xyz="-1 0 0" />
<limit effort="40" lower="-3.3161253267948965" upper="0.17453267320510335" velocity="16.754666" />
</joint>
<link name="openarm_left_link3">
<visual name="openarm_left_link3_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link3.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link3_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link3_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.002104752099628911 -0.0005549085042607548 0.09047470545721961" />
<mass value="1.073863338202347" />
<inertia ixx="0.004372" ixy="1e-06" ixz="1.1e-05" iyy="0.004319" iyz="-3.6e-05" izz="0.000661" />
</inertial>
</link>
<joint name="openarm_left_joint3" type="revolute">
<origin rpy="0 0 0" xyz="0.0301 0.0 0.06625" />
<parent link="openarm_left_link2" />
<child link="openarm_left_link3" />
<axis xyz="0 0 1" />
<limit effort="27" lower="-1.570796" upper="1.570796" velocity="5.445426" />
</joint>
<link name="openarm_left_link4">
<visual name="openarm_left_link4_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link4.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link4_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link4_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.0029006831074562967 -0.03030575826634669 0.06339637422196209" />
<mass value="0.6348534566833373" />
<inertia ixx="0.000623" ixy="-1e-06" ixz="-1.9e-05" iyy="0.000511" iyz="3.8e-05" izz="0.000334" />
</inertial>
</link>
<joint name="openarm_left_joint4" type="revolute">
<origin rpy="0 0 0" xyz="-0.0 0.0315 0.15375" />
<parent link="openarm_left_link3" />
<child link="openarm_left_link4" />
<axis xyz="0 1 0" />
<limit effort="27" lower="0.0" upper="2.443461" velocity="5.445426" />
</joint>
<link name="openarm_left_link5">
<visual name="openarm_left_link5_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link5.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link5_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link5_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.003049665024221911 -0.0008866902457326625 0.043079803024980934" />
<mass value="0.6156588026168502" />
<inertia ixx="0.000423" ixy="-8e-06" ixz="6e-06" iyy="0.000445" iyz="-6e-06" izz="0.000324" />
</inertial>
</link>
<joint name="openarm_left_joint5" type="revolute">
<origin rpy="0 0 0" xyz="0.0 -0.0315 0.0955" />
<parent link="openarm_left_link4" />
<child link="openarm_left_link5" />
<axis xyz="0 0 1" />
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
</joint>
<link name="openarm_left_link6">
<visual name="openarm_left_link6_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link6.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link6_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link6_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.037136587005447405 -0.00033230528343419053 -9.498374522309838e-05" />
<mass value="0.475202773187987" />
<inertia ixx="0.000143" ixy="1e-06" ixz="1e-06" iyy="0.000157" iyz="1e-06" izz="0.000159" />
</inertial>
</link>
<joint name="openarm_left_joint6" type="revolute">
<origin rpy="0 0 0" xyz="0.0375 0.0 0.1205" />
<parent link="openarm_left_link5" />
<child link="openarm_left_link6" />
<axis xyz="1 0 0" />
<limit effort="7" lower="-0.785398" upper="0.785398" velocity="20.943946" />
</joint>
<link name="openarm_left_link7">
<visual name="openarm_left_link7_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link7.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link7_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link7_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="6.875510271106056e-05 -0.01266175250761268 0.06951945409987448" />
<mass value="0.4659771327380578" />
<inertia ixx="0.000639" ixy="1e-06" ixz="1e-06" iyy="0.000497" iyz="8.9e-05" izz="0.000342" />
</inertial>
</link>
<joint name="openarm_left_joint7" type="revolute">
<origin rpy="0 0 0" xyz="-0.0375 0.0 0.0" />
<parent link="openarm_left_link6" />
<child link="openarm_left_link7" />
<axis xyz="0 -1 0" />
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
</joint>
<joint name="openarm_right_openarm_body_link0_joint" type="fixed">
<parent link="openarm_body_link0" />
<child link="openarm_right_link0" />
<origin rpy="1.5708 0 0" xyz="0.0 -0.031 0.698" />
</joint>
<link name="openarm_right_link0">
<visual name="openarm_right_link0_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link0.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link0_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link0_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.0009483362816297526 0.0001580207020448382 0.03076860287587199" />
<mass value="1.1432284943239561" />
<inertia ixx="0.001128" ixy="-4e-06" ixz="-3.3e-05" iyy="0.000962" iyz="-7e-06" izz="0.00147" />
</inertial>
</link>
<link name="openarm_right_link1">
<visual name="openarm_right_link1_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link1.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link1_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link1_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="0.0011467657911800769 3.319987657026362e-05 0.05395284380736254" />
<mass value="1.1416684646202298" />
<inertia ixx="0.001567" ixy="-1e-06" ixz="-2.9e-05" iyy="0.001273" iyz="1e-06" izz="0.001016" />
</inertial>
</link>
<joint name="openarm_right_joint1" type="revolute">
<origin rpy="0 0 0" xyz="0.0 0.0 0.0625" />
<parent link="openarm_right_link0" />
<child link="openarm_right_link1" />
<axis xyz="0 0 1" />
<limit effort="40" lower="-1.396263" upper="3.490659" velocity="16.754666" />
</joint>
<link name="openarm_right_link2">
<visual name="openarm_right_link2_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link2.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link2_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link2_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="0.00839629182351943 -2.0145102027597523e-08 0.03256649300522363" />
<mass value="0.2775092746011571" />
<inertia ixx="0.000359" ixy="1e-06" ixz="-0.000109" iyy="0.000376" iyz="1e-06" izz="0.000232" />
</inertial>
</link>
<joint name="openarm_right_joint2" type="revolute">
<origin rpy="1.57079632679 0 0" xyz="-0.0301 0.0 0.06" />
<parent link="openarm_right_link1" />
<child link="openarm_right_link2" />
<axis xyz="-1 0 0" />
<limit effort="40" lower="-0.17453267320510335" upper="3.3161253267948965" velocity="16.754666" />
</joint>
<link name="openarm_right_link3">
<visual name="openarm_right_link3_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link3.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link3_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link3_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.002104752099628911 0.0005549085042607548 0.09047470545721961" />
<mass value="1.073863338202347" />
<inertia ixx="0.004372" ixy="1e-06" ixz="1.1e-05" iyy="0.004319" iyz="-3.6e-05" izz="0.000661" />
</inertial>
</link>
<joint name="openarm_right_joint3" type="revolute">
<origin rpy="0 0 0" xyz="0.0301 0.0 0.06625" />
<parent link="openarm_right_link2" />
<child link="openarm_right_link3" />
<axis xyz="0 0 1" />
<limit effort="27" lower="-1.570796" upper="1.570796" velocity="5.445426" />
</joint>
<link name="openarm_right_link4">
<visual name="openarm_right_link4_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link4.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link4_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link4_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.0029006831074562967 -0.03030575826634669 0.06339637422196209" />
<mass value="0.6348534566833373" />
<inertia ixx="0.000623" ixy="-1e-06" ixz="-1.9e-05" iyy="0.000511" iyz="3.8e-05" izz="0.000334" />
</inertial>
</link>
<joint name="openarm_right_joint4" type="revolute">
<origin rpy="0 0 0" xyz="-0.0 0.0315 0.15375" />
<parent link="openarm_right_link3" />
<child link="openarm_right_link4" />
<axis xyz="0 1 0" />
<limit effort="27" lower="0.0" upper="2.443461" velocity="5.445426" />
</joint>
<link name="openarm_right_link5">
<visual name="openarm_right_link5_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link5.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link5_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link5_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.003049665024221911 0.0008866902457326625 0.043079803024980934" />
<mass value="0.6156588026168502" />
<inertia ixx="0.000423" ixy="-8e-06" ixz="6e-06" iyy="0.000445" iyz="-6e-06" izz="0.000324" />
</inertial>
</link>
<joint name="openarm_right_joint5" type="revolute">
<origin rpy="0 0 0" xyz="0.0 -0.0315 0.0955" />
<parent link="openarm_right_link4" />
<child link="openarm_right_link5" />
<axis xyz="0 0 1" />
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
</joint>
<link name="openarm_right_link6">
<visual name="openarm_right_link6_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link6.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link6_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link6_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.037136587005447405 0.00033230528343419053 -9.498374522309838e-05" />
<mass value="0.475202773187987" />
<inertia ixx="0.000143" ixy="1e-06" ixz="1e-06" iyy="0.000157" iyz="1e-06" izz="0.000159" />
</inertial>
</link>
<joint name="openarm_right_joint6" type="revolute">
<origin rpy="0 0 0" xyz="0.0375 0.0 0.1205" />
<parent link="openarm_right_link5" />
<child link="openarm_right_link6" />
<axis xyz="1 0 0" />
<limit effort="7" lower="-0.785398" upper="0.785398" velocity="20.943946" />
</joint>
<link name="openarm_right_link7">
<visual name="openarm_right_link7_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link7.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link7_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link7_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="6.875510271106056e-05 0.01266175250761268 0.06951945409987448" />
<mass value="0.4659771327380578" />
<inertia ixx="0.000639" ixy="1e-06" ixz="1e-06" iyy="0.000497" iyz="8.9e-05" izz="0.000342" />
</inertial>
</link>
<joint name="openarm_right_joint7" type="revolute">
<origin rpy="0 0 0" xyz="-0.0375 0.0 0.0" />
<parent link="openarm_right_link6" />
<child link="openarm_right_link7" />
<axis xyz="0 1 0" />
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
</joint>
<link name="openarm_left_hand">
<visual name="openarm_left_hand_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/visual/hand.dae" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_hand_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/collision/hand.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0 0 0" xyz="0.0 0.002 0.03" />
<mass value="0.35" />
<inertia ixx="0.0002473" ixy="1e-06" ixz="1e-06" iyy="1.763e-05" iyz="1e-06" izz="0.0002521" />
</inertial>
</link>
<joint name="left_openarm_hand_joint" type="fixed">
<parent link="openarm_left_link7" />
<child link="openarm_left_hand" />
<origin rpy="0 0 0" xyz="0 -0.0 0.1001" />
</joint>
<link name="openarm_left_hand_tcp">
<inertial>
<origin xyz="0 0 0" rpy="0 0 0" />
<mass value="0.001" />
<inertia ixx="0.000001" ixy="0.0" ixz="0.0" iyy="0.000001" iyz="0.0" izz="0.000001" />
</inertial>
</link>
<joint name="openarm_left_hand_tcp_joint" type="fixed">
<origin rpy="0 0 0" xyz="0 -0.0 0.08" />
<parent link="openarm_left_hand" />
<child link="openarm_left_hand_tcp" />
</joint>
<link name="openarm_left_left_finger">
<visual name="openarm_left_left_finger_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_left_finger_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0 0 0" xyz="0.0064528 0.01702 0.0219685" />
<mass value="0.03602545343277134" />
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
</inertial>
</link>
<link name="openarm_left_right_finger">
<visual name="openarm_left_right_finger_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_right_finger_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0 0 0" xyz="0.0064528 -0.01702 0.0219685" />
<mass value="0.03602545343277134" />
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
</inertial>
</link>
<joint name="openarm_left_finger_joint1" type="prismatic">
<parent link="openarm_left_hand" />
<child link="openarm_left_right_finger" />
<origin rpy="0 0 0" xyz="0 -0.006 0.015" />
<axis xyz="0 -1 0" />
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
</joint>
<joint name="openarm_left_finger_joint2" type="prismatic">
<parent link="openarm_left_hand" />
<child link="openarm_left_left_finger" />
<origin rpy="0 0 0" xyz="0 0.006 0.015" />
<axis xyz="0 1 0" />
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
<mimic joint="openarm_left_finger_joint1" />
</joint>
<link name="openarm_right_hand">
<visual name="openarm_right_hand_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/visual/hand.dae" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_hand_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/collision/hand.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0 0 0" xyz="0.0 0.002 0.03" />
<mass value="0.35" />
<inertia ixx="0.0002473" ixy="1e-06" ixz="1e-06" iyy="1.763e-05" iyz="1e-06" izz="0.0002521" />
</inertial>
</link>
<link name="openarm_right_ee_target">
<inertial>
<origin xyz="0 0 0" rpy="0 0 0" />
<mass value="0.001" />
<inertia ixx="0.000001" ixy="0.0" ixz="0.0" iyy="0.000001" iyz="0.0" izz="0.000001" />
</inertial>
</link>
<joint name="openarm_right_ee_target_joint" type="fixed">
<parent link="openarm_right_link7" />
<child link="openarm_right_ee_target" />
<origin rpy="0 0 0" xyz="0 0.0 0.07" />
</joint>
<joint name="right_openarm_hand_joint" type="fixed">
<parent link="openarm_right_link7" />
<child link="openarm_right_hand" />
<origin rpy="0 0 0" xyz="0 -0.0 0.1001" />
</joint>
<link name="openarm_right_hand_tcp">
<inertial>
<origin xyz="0 0 0" rpy="0 0 0" />
<mass value="0.001" />
<inertia ixx="0.000001" ixy="0.0" ixz="0.0" iyy="0.000001" iyz="0.0" izz="0.000001" />
</inertial>
</link>
<joint name="openarm_right_hand_tcp_joint" type="fixed">
<origin rpy="0 0 0" xyz="0 -0.0 0.08" />
<parent link="openarm_right_hand" />
<child link="openarm_right_hand_tcp" />
</joint>
<link name="openarm_right_left_finger">
<visual name="openarm_right_left_finger_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_left_finger_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0 0 0" xyz="0.0064528 0.01702 0.0219685" />
<mass value="0.03602545343277134" />
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
</inertial>
</link>
<link name="openarm_right_right_finger">
<visual name="openarm_right_right_finger_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_right_finger_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0 0 0" xyz="0.0064528 -0.01702 0.0219685" />
<mass value="0.03602545343277134" />
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
</inertial>
</link>
<joint name="openarm_right_finger_joint1" type="prismatic">
<parent link="openarm_right_hand" />
<child link="openarm_right_right_finger" />
<origin rpy="0 0 0" xyz="0 -0.006 0.015" />
<axis xyz="0 -1 0" />
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
</joint>
<joint name="openarm_right_finger_joint2" type="prismatic">
<parent link="openarm_right_hand" />
<child link="openarm_right_left_finger" />
<origin rpy="0 0 0" xyz="0 0.006 0.015" />
<axis xyz="0 1 0" />
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
<mimic joint="openarm_right_finger_joint1" />
</joint>
</robot>

View File

@@ -0,0 +1,408 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<title>Dataset Replay — EE Frame Viewer</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body { background: #0d1117; overflow: hidden; font-family: 'JetBrains Mono', monospace; color: #c9d1d9; }
canvas { display: block; }
#panel {
position: absolute; top: 14px; left: 14px;
background: rgba(13,17,23,0.92); border: 1px solid #30363d;
border-radius: 10px; padding: 16px 20px; z-index: 10;
width: 340px; backdrop-filter: blur(8px);
}
#panel h2 { font-size: 14px; color: #58a6ff; margin-bottom: 10px; letter-spacing: 0.5px; }
.row { display: flex; align-items: center; gap: 8px; margin: 6px 0; font-size: 12px; }
.row label { width: 70px; color: #8b949e; flex-shrink: 0; }
.row .val { color: #f0f6fc; font-variant-numeric: tabular-nums; }
#transport {
margin-top: 12px; display: flex; align-items: center; gap: 8px;
}
#transport button {
background: #21262d; color: #c9d1d9; border: 1px solid #30363d;
padding: 6px 14px; border-radius: 6px; cursor: pointer;
font-family: inherit; font-size: 12px; transition: background 0.15s;
}
#transport button:hover { background: #30363d; }
#transport button.active { background: #1f6feb; border-color: #1f6feb; color: #fff; }
#scrubber {
width: 100%; margin-top: 8px;
-webkit-appearance: none; appearance: none;
height: 6px; border-radius: 3px; background: #21262d; outline: none;
}
#scrubber::-webkit-slider-thumb {
-webkit-appearance: none; width: 14px; height: 14px;
border-radius: 50%; background: #58a6ff; cursor: pointer;
}
#speed-ctrl { margin-top: 6px; }
#speed-ctrl select {
background: #21262d; color: #c9d1d9; border: 1px solid #30363d;
padding: 4px 8px; border-radius: 4px; font-family: inherit; font-size: 11px;
}
#frame-counter {
font-size: 11px; color: #8b949e; margin-top: 6px;
font-variant-numeric: tabular-nums;
}
.legend { display: flex; align-items: center; gap: 6px; margin: 3px 0; font-size: 11px; }
.dot { width: 10px; height: 10px; border-radius: 50%; display: inline-block; }
</style>
<link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600&display=swap" rel="stylesheet">
</head>
<body>
<div id="panel">
<h2>DATASET REPLAY — EE FRAME</h2>
<div style="font-size:11px;color:#8b949e;margin-bottom:8px;">glannuzel/grabette-dataset · episode 0</div>
<div class="legend"><span class="dot" style="background:#ff6b6b"></span> EE target (dataset)</div>
<div class="legend"><span class="dot" style="background:#ffd43b"></span> Trajectory (past)</div>
<div class="legend"><span class="dot" style="background:#30363d"></span> Trajectory (future)</div>
<div class="row"><label>x</label><span class="val" id="v-x"></span></div>
<div class="row"><label>y</label><span class="val" id="v-y"></span></div>
<div class="row"><label>z</label><span class="val" id="v-z"></span></div>
<div class="row"><label>ax</label><span class="val" id="v-ax"></span></div>
<div class="row"><label>ay</label><span class="val" id="v-ay"></span></div>
<div class="row"><label>az</label><span class="val" id="v-az"></span></div>
<div class="row"><label>gripper</label><span class="val" id="v-grip"></span></div>
<div id="transport">
<button id="btn-play" onclick="togglePlay()">▶ Play</button>
<button onclick="stepFrame(-1)"></button>
<button onclick="stepFrame(1)"></button>
<button onclick="resetPlay()"></button>
</div>
<input type="range" id="scrubber" min="0" max="1" value="0" step="1" />
<div id="speed-ctrl">
<label style="font-size:11px;color:#8b949e;">Speed:</label>
<select id="speed-select" onchange="setSpeed(this.value)">
<option value="0.25">0.25×</option>
<option value="0.5">0.5×</option>
<option value="1" selected>1×</option>
<option value="2">2×</option>
<option value="4">4×</option>
</select>
</div>
<div id="frame-counter">Frame 0 / 0 · 0.00s</div>
</div>
<script type="importmap">
{
"imports": {
"three": "https://cdn.jsdelivr.net/npm/three@0.169.0/build/three.module.js",
"three/examples/jsm/": "https://cdn.jsdelivr.net/npm/three@0.169.0/examples/jsm/"
}
}
</script>
<script type="module">
import * as THREE from 'three';
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls.js';
import { STLLoader } from 'three/examples/jsm/loaders/STLLoader.js';
let trajectory = null;
let currentFrame = 0;
let playing = false;
let speed = 1.0;
let lastTime = 0;
let accumulator = 0;
// Anchor: EE tip world position at zero-joint pose (in Y-up Three.js space)
const eeAnchor = new THREE.Vector3();
// Z-up → Y-up rotation (same as robotGroup): -90° around X
const zUpToYUp = new THREE.Quaternion().setFromAxisAngle(new THREE.Vector3(1, 0, 0), -Math.PI / 2);
const scene = new THREE.Scene();
scene.background = new THREE.Color(0x0d1117);
const camera = new THREE.PerspectiveCamera(50, window.innerWidth / window.innerHeight, 0.01, 100);
const renderer = new THREE.WebGLRenderer({ antialias: true });
renderer.setSize(window.innerWidth, window.innerHeight);
renderer.setPixelRatio(window.devicePixelRatio);
renderer.shadowMap.enabled = true;
document.body.appendChild(renderer.domElement);
const controls = new OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
controls.dampingFactor = 0.08;
scene.add(new THREE.AmbientLight(0xffffff, 0.8));
const dirLight = new THREE.DirectionalLight(0xffffff, 1.4);
dirLight.position.set(2, 4, 3);
scene.add(dirLight);
scene.add(new THREE.DirectionalLight(0x8899cc, 0.6).translateX(-2).translateY(1).translateZ(-3));
scene.add(new THREE.DirectionalLight(0xffffff, 0.5).translateY(-1).translateZ(2));
const grid = new THREE.GridHelper(2, 20, 0x21262d, 0x161b22);
scene.add(grid);
scene.add(new THREE.AxesHelper(0.15));
// EE marker
const eeMarker = new THREE.Mesh(
new THREE.SphereGeometry(0.012, 20, 20),
new THREE.MeshStandardMaterial({ color: 0xff6b6b, emissive: 0xff6b6b, emissiveIntensity: 0.7 })
);
scene.add(eeMarker);
eeMarker.add(new THREE.AxesHelper(0.06));
// Trajectory lines
const MAX_POINTS = 2000;
const pastGeo = new THREE.BufferGeometry();
pastGeo.setAttribute('position', new THREE.Float32BufferAttribute(new Float32Array(MAX_POINTS * 3), 3));
const pastLine = new THREE.Line(pastGeo, new THREE.LineBasicMaterial({ color: 0xffd43b, linewidth: 2 }));
scene.add(pastLine);
const futureGeo = new THREE.BufferGeometry();
futureGeo.setAttribute('position', new THREE.Float32BufferAttribute(new Float32Array(MAX_POINTS * 3), 3));
const futureLine = new THREE.Line(futureGeo, new THREE.LineBasicMaterial({ color: 0x30363d, linewidth: 1 }));
scene.add(futureLine);
// URDF
const stlLoader = new STLLoader();
const robotGroup = new THREE.Group();
// URDF is Z-up; Three.js is Y-up → rotate -90° around X
robotGroup.rotation.x = -Math.PI / 2;
scene.add(robotGroup);
let urdfLinks = {};
function rotvecToQuat(ax, ay, az) {
const angle = Math.sqrt(ax * ax + ay * ay + az * az);
if (angle < 1e-8) return new THREE.Quaternion();
return new THREE.Quaternion().setFromAxisAngle(
new THREE.Vector3(ax / angle, ay / angle, az / angle), angle
);
}
async function loadURDF() {
const resp = await fetch('./openarm_bimanual_pybullet.urdf');
const text = await resp.text();
const xml = new DOMParser().parseFromString(text, 'text/xml');
const links = {};
for (const linkEl of xml.querySelectorAll('link')) {
const name = linkEl.getAttribute('name');
const group = new THREE.Group();
group.name = name;
const visual = linkEl.querySelector('visual');
if (visual) {
const meshEl = visual.querySelector('mesh');
const originEl = visual.querySelector('origin');
if (meshEl) {
const filename = meshEl.getAttribute('filename');
const scaleStr = meshEl.getAttribute('scale');
const sc = scaleStr ? scaleStr.split(' ').map(Number) : [1, 1, 1];
let xyz = [0, 0, 0];
if (originEl && originEl.getAttribute('xyz'))
xyz = originEl.getAttribute('xyz').split(' ').map(Number);
if (filename.endsWith('.stl')) {
try {
const geo = await new Promise((res, rej) =>
stlLoader.load(filename, res, undefined, rej));
const mesh = new THREE.Mesh(geo, new THREE.MeshStandardMaterial({
color: 0x8899bb, metalness: 0.3, roughness: 0.5,
}));
mesh.scale.set(sc[0], sc[1], sc[2]);
mesh.position.set(xyz[0], xyz[1], xyz[2]);
group.add(mesh);
} catch (e) { /* skip missing mesh */ }
}
}
}
links[name] = group;
}
const rootLinks = new Set(Object.keys(links));
for (const jointEl of xml.querySelectorAll('joint')) {
const parentName = jointEl.querySelector('parent').getAttribute('link');
const childName = jointEl.querySelector('child').getAttribute('link');
rootLinks.delete(childName);
const originEl = jointEl.querySelector('origin');
let xyz = [0, 0, 0], rpy = [0, 0, 0];
if (originEl) {
if (originEl.getAttribute('xyz')) xyz = originEl.getAttribute('xyz').split(' ').map(Number);
if (originEl.getAttribute('rpy')) rpy = originEl.getAttribute('rpy').split(' ').map(Number);
}
const parent = links[parentName];
const child = links[childName];
if (!parent || !child) continue;
child.position.set(xyz[0], xyz[1], xyz[2]);
if (rpy[0] || rpy[1] || rpy[2])
child.rotation.set(rpy[0], rpy[1], rpy[2], 'XYZ');
parent.add(child);
}
for (const n of rootLinks)
if (links[n]) robotGroup.add(links[n]);
// EE target marker on the URDF
const eeTargetLink = links['openarm_right_ee_target'];
if (eeTargetLink) {
eeTargetLink.add(new THREE.Mesh(
new THREE.TorusGeometry(0.02, 0.002, 8, 32),
new THREE.MeshStandardMaterial({ color: 0xffaa00, emissive: 0xffaa00, emissiveIntensity: 0.5 })
));
eeTargetLink.add(new THREE.AxesHelper(0.05));
}
urdfLinks = links;
}
async function loadTrajectory() {
const resp = await fetch('./trajectory_ep0.json');
trajectory = await resp.json();
document.getElementById('scrubber').max = trajectory.num_frames - 1;
document.getElementById('scrubber').value = 0;
}
function computeOffset() {
if (!trajectory || !urdfLinks['openarm_right_ee_target']) return;
robotGroup.updateMatrixWorld(true);
const eeLink = urdfLinks['openarm_right_ee_target'];
eeLink.getWorldPosition(eeAnchor);
controls.target.copy(eeAnchor);
camera.position.set(eeAnchor.x + 0.8, eeAnchor.y + 0.3, eeAnchor.z + 0.0);
controls.update();
updateFrame(0);
}
function mapFramePos(f) {
const f0 = trajectory.frames[0];
const delta = new THREE.Vector3(f.x - f0.x, f.y - f0.y, f.z - f0.z);
delta.applyQuaternion(zUpToYUp);
return delta.add(eeAnchor);
}
function updateFrame(idx) {
if (!trajectory) return;
currentFrame = Math.max(0, Math.min(idx, trajectory.num_frames - 1));
const f = trajectory.frames[currentFrame];
const pos = mapFramePos(f);
eeMarker.position.copy(pos);
// Orientation: rotate the dataset axis-angle into Y-up space
const q = rotvecToQuat(f.ax, f.ay, f.az);
eeMarker.quaternion.copy(zUpToYUp).multiply(q);
// Past trajectory
const pastArr = pastGeo.attributes.position.array;
let pi = 0;
for (let i = 0; i <= currentFrame && i < MAX_POINTS; i++) {
const p = mapFramePos(trajectory.frames[i]);
pastArr[pi++] = p.x; pastArr[pi++] = p.y; pastArr[pi++] = p.z;
}
pastGeo.setDrawRange(0, Math.min(currentFrame + 1, MAX_POINTS));
pastGeo.attributes.position.needsUpdate = true;
// Future trajectory
const futArr = futureGeo.attributes.position.array;
let fi = 0;
for (let i = currentFrame; i < trajectory.num_frames && (i - currentFrame) < MAX_POINTS; i++) {
const p = mapFramePos(trajectory.frames[i]);
futArr[fi++] = p.x; futArr[fi++] = p.y; futArr[fi++] = p.z;
}
futureGeo.setDrawRange(0, Math.min(trajectory.num_frames - currentFrame, MAX_POINTS));
futureGeo.attributes.position.needsUpdate = true;
// UI
document.getElementById('v-x').textContent = pos.x.toFixed(4);
document.getElementById('v-y').textContent = pos.y.toFixed(4);
document.getElementById('v-z').textContent = pos.z.toFixed(4);
document.getElementById('v-ax').textContent = f.ax.toFixed(4);
document.getElementById('v-ay').textContent = f.ay.toFixed(4);
document.getElementById('v-az').textContent = f.az.toFixed(4);
document.getElementById('v-grip').textContent =
`p=${f.proximal.toFixed(2)} d=${f.distal.toFixed(2)}`;
document.getElementById('scrubber').value = currentFrame;
const timeS = (currentFrame / trajectory.fps).toFixed(2);
document.getElementById('frame-counter').textContent =
`Frame ${currentFrame} / ${trajectory.num_frames - 1} · ${timeS}s`;
}
// Playback controls
window.togglePlay = function() {
playing = !playing;
const btn = document.getElementById('btn-play');
btn.textContent = playing ? '⏸ Pause' : '▶ Play';
btn.classList.toggle('active', playing);
if (playing) { lastTime = performance.now(); accumulator = 0; }
};
window.stepFrame = function(delta) {
playing = false;
document.getElementById('btn-play').textContent = '▶ Play';
document.getElementById('btn-play').classList.remove('active');
updateFrame(currentFrame + delta);
};
window.resetPlay = function() {
playing = false;
document.getElementById('btn-play').textContent = '▶ Play';
document.getElementById('btn-play').classList.remove('active');
updateFrame(0);
};
window.setSpeed = function(v) { speed = parseFloat(v); };
document.getElementById('scrubber').addEventListener('input', (e) => {
updateFrame(parseInt(e.target.value));
});
window.addEventListener('resize', () => {
camera.aspect = window.innerWidth / window.innerHeight;
camera.updateProjectionMatrix();
renderer.setSize(window.innerWidth, window.innerHeight);
});
function animate(now) {
requestAnimationFrame(animate);
controls.update();
if (playing && trajectory) {
const dt = (now - lastTime) / 1000;
lastTime = now;
accumulator += dt * speed;
const frameDuration = 1.0 / trajectory.fps;
while (accumulator >= frameDuration) {
accumulator -= frameDuration;
if (currentFrame < trajectory.num_frames - 1) {
updateFrame(currentFrame + 1);
} else {
playing = false;
document.getElementById('btn-play').textContent = '▶ Play';
document.getElementById('btn-play').classList.remove('active');
break;
}
}
}
renderer.render(scene, camera);
}
requestAnimationFrame(animate);
Promise.all([loadURDF(), loadTrajectory()])
.then(() => computeOffset())
.catch(err => console.error(err));
</script>
</body>
</html>

View File

@@ -0,0 +1,311 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<title>OpenArm URDF Viewer</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body { background: #1a1a2e; overflow: hidden; font-family: 'IBM Plex Mono', monospace; }
canvas { display: block; }
#info {
position: absolute; top: 16px; left: 16px;
color: #e0e0e0; font-size: 13px; line-height: 1.6;
background: rgba(0,0,0,0.7); padding: 14px 18px; border-radius: 8px;
border: 1px solid #333; max-width: 340px; z-index: 10;
}
#info h2 { font-size: 15px; color: #fff; margin-bottom: 8px; }
.legend { display: flex; align-items: center; gap: 8px; margin: 4px 0; }
.dot { width: 12px; height: 12px; border-radius: 50%; display: inline-block; flex-shrink: 0; }
.dot-red { background: #ff4444; }
.dot-green { background: #44ff44; }
.dot-blue { background: #4488ff; }
#frame-select { margin-top: 10px; }
#frame-select button {
background: #333; color: #e0e0e0; border: 1px solid #555;
padding: 6px 10px; margin: 2px; border-radius: 4px; cursor: pointer;
font-family: inherit; font-size: 12px;
}
#frame-select button:hover { background: #555; }
#frame-select button.active { background: #4488ff; color: #fff; border-color: #4488ff; }
#status { margin-top: 8px; font-size: 11px; color: #888; }
</style>
<link href="https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&display=swap" rel="stylesheet">
</head>
<body>
<div id="info">
<h2>OpenArm Right Arm — EE Frame Options</h2>
<div class="legend"><span class="dot dot-red"></span> openarm_right_link7 (wrist output)</div>
<div class="legend"><span class="dot" style="background:#ffaa00"></span> openarm_right_ee_target (+5cm)</div>
<div class="legend"><span class="dot dot-green"></span> openarm_right_hand (+10cm)</div>
<div class="legend"><span class="dot dot-blue"></span> openarm_right_hand_tcp (+18cm)</div>
<div id="frame-select">
<button onclick="focusFrame('link7')" class="active">link7</button>
<button onclick="focusFrame('ee_target')">ee_target</button>
<button onclick="focusFrame('hand')">hand</button>
<button onclick="focusFrame('tcp')">hand_tcp</button>
</div>
<div id="status">Loading URDF...</div>
<p style="margin-top:8px;font-size:11px;color:#888;">Drag to orbit · Scroll to zoom · Right-drag to pan</p>
</div>
<script type="importmap">
{
"imports": {
"three": "https://cdn.jsdelivr.net/npm/three@0.169.0/build/three.module.js",
"three/examples/jsm/": "https://cdn.jsdelivr.net/npm/three@0.169.0/examples/jsm/"
}
}
</script>
<script type="module">
import * as THREE from 'three';
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls.js';
import { STLLoader } from 'three/examples/jsm/loaders/STLLoader.js';
const statusEl = document.getElementById('status');
const scene = new THREE.Scene();
scene.background = new THREE.Color(0x1a1a2e);
const camera = new THREE.PerspectiveCamera(50, window.innerWidth / window.innerHeight, 0.01, 100);
camera.position.set(0.8, 1.0, 1.8);
const renderer = new THREE.WebGLRenderer({ antialias: true });
renderer.setSize(window.innerWidth, window.innerHeight);
renderer.setPixelRatio(window.devicePixelRatio);
renderer.shadowMap.enabled = true;
document.body.appendChild(renderer.domElement);
const controls = new OrbitControls(camera, renderer.domElement);
controls.target.set(0, 0, 0.9);
controls.enableDamping = true;
controls.dampingFactor = 0.08;
controls.update();
// Lighting
scene.add(new THREE.AmbientLight(0xffffff, 0.6));
const dirLight = new THREE.DirectionalLight(0xffffff, 1.2);
dirLight.position.set(3, 5, 4);
scene.add(dirLight);
scene.add(new THREE.DirectionalLight(0x8888ff, 0.4).translateX(-2).translateY(1).translateZ(-3));
// Ground grid
scene.add(new THREE.GridHelper(4, 40, 0x333355, 0x222244));
scene.add(new THREE.AxesHelper(0.3));
// Parse URDF manually — build the kinematic tree and load STL meshes
const stlLoader = new STLLoader();
const robotGroup = new THREE.Group();
scene.add(robotGroup);
async function loadURDF() {
const resp = await fetch('./openarm_bimanual_pybullet.urdf');
const text = await resp.text();
const parser = new DOMParser();
const xml = parser.parseFromString(text, 'text/xml');
// Parse links and joints
const links = {};
const joints = [];
for (const linkEl of xml.querySelectorAll('link')) {
const name = linkEl.getAttribute('name');
const group = new THREE.Group();
group.name = name;
// Try to load visual mesh
const visual = linkEl.querySelector('visual');
if (visual) {
const meshEl = visual.querySelector('mesh');
const originEl = visual.querySelector('origin');
if (meshEl) {
const filename = meshEl.getAttribute('filename');
const scaleStr = meshEl.getAttribute('scale');
const scale = scaleStr ? scaleStr.split(' ').map(Number) : [1, 1, 1];
let xyz = [0, 0, 0];
if (originEl && originEl.getAttribute('xyz')) {
xyz = originEl.getAttribute('xyz').split(' ').map(Number);
}
if (filename.endsWith('.stl')) {
try {
const geo = await new Promise((resolve, reject) => {
stlLoader.load(filename, resolve, undefined, reject);
});
const mat = new THREE.MeshStandardMaterial({
color: 0x6688aa,
metalness: 0.3,
roughness: 0.6,
transparent: true,
opacity: 0.7,
});
const mesh = new THREE.Mesh(geo, mat);
mesh.scale.set(scale[0], scale[1], scale[2]);
mesh.position.set(xyz[0], xyz[1], xyz[2]);
group.add(mesh);
} catch (e) {
// Mesh file not found, skip
}
}
}
}
links[name] = group;
}
// Parse joints and build hierarchy
for (const jointEl of xml.querySelectorAll('joint')) {
const name = jointEl.getAttribute('name');
const type = jointEl.getAttribute('type');
const parentName = jointEl.querySelector('parent').getAttribute('link');
const childName = jointEl.querySelector('child').getAttribute('link');
const originEl = jointEl.querySelector('origin');
let xyz = [0, 0, 0];
let rpy = [0, 0, 0];
if (originEl) {
if (originEl.getAttribute('xyz')) xyz = originEl.getAttribute('xyz').split(' ').map(Number);
if (originEl.getAttribute('rpy')) rpy = originEl.getAttribute('rpy').split(' ').map(Number);
}
joints.push({ name, type, parentName, childName, xyz, rpy });
}
// Build tree
const rootLinks = new Set(Object.keys(links));
for (const j of joints) {
rootLinks.delete(j.childName);
}
for (const j of joints) {
const parent = links[j.parentName];
const child = links[j.childName];
if (parent && child) {
child.position.set(j.xyz[0], j.xyz[1], j.xyz[2]);
if (j.rpy[0] || j.rpy[1] || j.rpy[2]) {
child.rotation.set(j.rpy[0], j.rpy[1], j.rpy[2], 'XYZ');
}
parent.add(child);
}
}
// Add root links to scene
for (const name of rootLinks) {
if (links[name]) robotGroup.add(links[name]);
}
// Place markers at the three EE frame candidates
const targets = {
link7: 'openarm_right_link7',
ee_target: 'openarm_right_ee_target',
hand: 'openarm_right_hand',
tcp: 'openarm_right_hand_tcp',
};
const colors = { link7: 0xff4444, ee_target: 0xffaa00, hand: 0x44ff44, tcp: 0x4488ff };
const labels = { link7: 'link7', ee_target: 'ee_target', hand: 'hand', tcp: 'hand_tcp' };
for (const [key, linkName] of Object.entries(targets)) {
const link = links[linkName];
if (!link) continue;
// Marker sphere
const sphere = new THREE.Mesh(
new THREE.SphereGeometry(0.018, 16, 16),
new THREE.MeshStandardMaterial({ color: colors[key], emissive: colors[key], emissiveIntensity: 0.6 })
);
link.add(sphere);
// Ring around sphere for visibility
const ring = new THREE.Mesh(
new THREE.TorusGeometry(0.03, 0.003, 8, 32),
new THREE.MeshStandardMaterial({ color: colors[key], emissive: colors[key], emissiveIntensity: 0.4 })
);
link.add(ring);
// Axes helper
link.add(new THREE.AxesHelper(0.08));
// Sprite label
const canvas = document.createElement('canvas');
canvas.width = 512; canvas.height = 80;
const ctx = canvas.getContext('2d');
ctx.font = 'bold 36px IBM Plex Mono, monospace';
ctx.fillStyle = '#' + colors[key].toString(16).padStart(6, '0');
ctx.fillText(labels[key], 4, 50);
const tex = new THREE.CanvasTexture(canvas);
const sprite = new THREE.Sprite(new THREE.SpriteMaterial({ map: tex, depthTest: false }));
sprite.scale.set(0.3, 0.05, 1);
sprite.position.set(0.06, 0.0, 0.03);
link.add(sprite);
}
// Dashed lines between markers (in world space)
robotGroup.updateMatrixWorld(true);
const positions = {};
for (const [key, linkName] of Object.entries(targets)) {
const link = links[linkName];
if (link) {
const wp = new THREE.Vector3();
link.getWorldPosition(wp);
positions[key] = wp;
}
}
function addDashedLine(from, to) {
const geo = new THREE.BufferGeometry().setFromPoints([from, to]);
const mat = new THREE.LineDashedMaterial({ color: 0xaaaaaa, dashSize: 0.012, gapSize: 0.008 });
const line = new THREE.Line(geo, mat);
line.computeLineDistances();
scene.add(line);
}
if (positions.link7 && positions.hand) addDashedLine(positions.link7, positions.hand);
if (positions.hand && positions.tcp) addDashedLine(positions.hand, positions.tcp);
// Store for focus buttons
window._framePositions = positions;
window._links = links;
window._targets = targets;
// Focus on the hand area
if (positions.hand) {
controls.target.copy(positions.hand);
camera.position.set(positions.hand.x + 0.5, positions.hand.y + 0.4, positions.hand.z + 0.5);
controls.update();
}
const meshCount = robotGroup.children.length;
statusEl.textContent = `Loaded. Right arm chain visible with ${Object.keys(links).length} links.`;
}
window.focusFrame = function(key) {
const pos = window._framePositions?.[key];
if (!pos) return;
controls.target.copy(pos);
camera.position.set(pos.x + 0.35, pos.y + 0.25, pos.z + 0.35);
controls.update();
document.querySelectorAll('#frame-select button').forEach(b => b.classList.remove('active'));
event.target.classList.add('active');
};
window.addEventListener('resize', () => {
camera.aspect = window.innerWidth / window.innerHeight;
camera.updateProjectionMatrix();
renderer.setSize(window.innerWidth, window.innerHeight);
});
function animate() {
requestAnimationFrame(animate);
controls.update();
renderer.render(scene, camera);
}
animate();
loadURDF().catch(err => {
statusEl.textContent = `Error: ${err.message}`;
console.error(err);
});
</script>
</body>
</html>

View File

@@ -255,19 +255,16 @@ class InverseKinematicsEEToJoints(RobotActionProcessorStep):
"""
Computes desired joint positions from a target end-effector pose using inverse kinematics (IK).
This step translates a Cartesian command (position and orientation of the end-effector) into
the corresponding joint-space commands for each motor.
Attributes:
kinematics: The robot's kinematic model for inverse kinematics.
motor_names: A list of motor names for which to compute joint positions.
q_curr: Internal state storing the last joint positions, used as an initial guess for the IK solver.
motor_names: Arm joint names for IK computation.
gripper_names: Gripper joint name(s). ee.gripper_pos is written to all of them.
initial_guess_current_joints: If True, use the robot's current joint state as the IK guess.
If False, use the solution from the previous step.
"""
kinematics: RobotKinematics
motor_names: list[str]
gripper_names: list[str] = field(default_factory=lambda: ["gripper"])
q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
initial_guess_current_joints: bool = True
@@ -278,63 +275,73 @@ class InverseKinematicsEEToJoints(RobotActionProcessorStep):
wx = action.pop("ee.wx")
wy = action.pop("ee.wy")
wz = action.pop("ee.wz")
gripper_pos = action.pop("ee.gripper_pos")
if None in (x, y, z, wx, wy, wz, gripper_pos):
raise ValueError(
"Missing required end-effector pose components: ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_pos must all be present in action"
)
ee_keys = [x, y, z, wx, wy, wz]
if self.gripper_names:
gripper_pos = action.pop("ee.gripper_pos")
ee_keys.append(gripper_pos)
if None in ee_keys:
raise ValueError("Missing required end-effector pose components in action")
observation = self.transition.get(TransitionKey.OBSERVATION).copy()
if observation is None:
raise ValueError("Joints observation is require for computing robot kinematics")
raise ValueError("Joints observation is required for computing robot kinematics")
q_raw = np.array(
[float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")],
[
float(v)
for k, v in observation.items()
if isinstance(k, str) and k.endswith(".pos") and k.removesuffix(".pos") in self.motor_names
],
dtype=float,
)
if q_raw is None:
raise ValueError("Joints observation is require for computing robot kinematics")
if self.initial_guess_current_joints: # Use current joints as initial guess
if self.initial_guess_current_joints:
self.q_curr = q_raw
else: # Use previous ik solution as initial guess
else:
if self.q_curr is None:
self.q_curr = q_raw
# Build desired 4x4 transform from pos + rotvec (twist)
t_des = np.eye(4, dtype=float)
t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
t_des[:3, 3] = [x, y, z]
# Compute inverse kinematics
q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
self.q_curr = q_target
# TODO: This is sentitive to order of motor_names = q_target mapping
for i, name in enumerate(self.motor_names):
if name != "gripper":
action[f"{name}.pos"] = float(q_target[i])
else:
action["gripper.pos"] = float(gripper_pos)
action[f"{name}.pos"] = float(q_target[i])
if self.gripper_names:
for gname in self.gripper_names:
action[f"{gname}.pos"] = float(gripper_pos)
# When gripper_names is empty, gripper keys (e.g. proximal.pos, distal.pos)
# are already in the action dict as absolute positions — left untouched.
return action
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
ee_feats = ["x", "y", "z", "wx", "wy", "wz"]
if self.gripper_names:
ee_feats.append("gripper_pos")
for feat in ee_feats:
features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None)
for name in self.motor_names:
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
for name in self.gripper_names:
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
return features
def reset(self):
"""Resets the initial guess for the IK solver."""
self.q_curr = None
@@ -402,24 +409,39 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
def compute_forward_kinematics_joints_to_ee(
joints: dict[str, Any], kinematics: RobotKinematics, motor_names: list[str]
joints: dict[str, Any],
kinematics: RobotKinematics,
motor_names: list[str],
gripper_names: list[str] | None = None,
) -> dict[str, Any]:
if gripper_names is None:
gripper_names = ["gripper"]
motor_joint_values = [joints[f"{n}.pos"] for n in motor_names]
q = np.array(motor_joint_values, dtype=float)
t = kinematics.forward_kinematics(q)
pos = t[:3, 3]
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
gripper_pos = joints["gripper.pos"]
for n in motor_names:
joints.pop(f"{n}.pos")
joints["ee.x"] = float(pos[0])
joints["ee.y"] = float(pos[1])
joints["ee.z"] = float(pos[2])
joints["ee.wx"] = float(tw[0])
joints["ee.wy"] = float(tw[1])
joints["ee.wz"] = float(tw[2])
joints["ee.gripper_pos"] = float(gripper_pos)
# When gripper_names is non-empty, fold them into ee.gripper_pos (e.g. SO100).
# When empty, gripper joints pass through as-is (absolute position control).
if gripper_names:
gripper_pos = joints[f"{gripper_names[0]}.pos"]
for n in gripper_names:
joints.pop(f"{n}.pos", None)
joints["ee.gripper_pos"] = float(gripper_pos)
return joints
@@ -429,27 +451,33 @@ class ForwardKinematicsJointsToEEObservation(ObservationProcessorStep):
"""
Computes the end-effector pose from joint positions using forward kinematics (FK).
This step is typically used to add the robot's Cartesian pose to the observation space,
which can be useful for visualization or as an input to a policy.
Attributes:
kinematics: The robot's kinematic model.
motor_names: Arm joint names used for FK computation.
gripper_names: Gripper joint name(s) to fold into ee.gripper_pos.
Empty list means gripper joints pass through as absolute positions.
"""
kinematics: RobotKinematics
motor_names: list[str]
gripper_names: list[str] = field(default_factory=lambda: ["gripper"])
def observation(self, observation: RobotObservation) -> RobotObservation:
return compute_forward_kinematics_joints_to_ee(observation, self.kinematics, self.motor_names)
return compute_forward_kinematics_joints_to_ee(
observation, self.kinematics, self.motor_names, self.gripper_names
)
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We only use the ee pose in the dataset, so we don't need the joint positions
for n in self.motor_names:
features[PipelineFeatureType.OBSERVATION].pop(f"{n}.pos", None)
# We specify the dataset features of this step that we want to be stored in the dataset
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
ee_keys = ["x", "y", "z", "wx", "wy", "wz"]
if self.gripper_names:
for n in self.gripper_names:
features[PipelineFeatureType.OBSERVATION].pop(f"{n}.pos", None)
ee_keys.append("gripper_pos")
for k in ee_keys:
features[PipelineFeatureType.OBSERVATION][f"ee.{k}"] = PolicyFeature(
type=FeatureType.STATE, shape=(1,)
)
@@ -462,27 +490,33 @@ class ForwardKinematicsJointsToEEAction(RobotActionProcessorStep):
"""
Computes the end-effector pose from joint positions using forward kinematics (FK).
This step is typically used to add the robot's Cartesian pose to the observation space,
which can be useful for visualization or as an input to a policy.
Attributes:
kinematics: The robot's kinematic model.
motor_names: Arm joint names used for FK computation.
gripper_names: Gripper joint name(s) to fold into ee.gripper_pos.
Empty list means gripper joints pass through as absolute positions.
"""
kinematics: RobotKinematics
motor_names: list[str]
gripper_names: list[str] = field(default_factory=lambda: ["gripper"])
def action(self, action: RobotAction) -> RobotAction:
return compute_forward_kinematics_joints_to_ee(action, self.kinematics, self.motor_names)
return compute_forward_kinematics_joints_to_ee(
action, self.kinematics, self.motor_names, self.gripper_names
)
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We only use the ee pose in the dataset, so we don't need the joint positions
for n in self.motor_names:
features[PipelineFeatureType.ACTION].pop(f"{n}.pos", None)
# We specify the dataset features of this step that we want to be stored in the dataset
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
ee_keys = ["x", "y", "z", "wx", "wy", "wz"]
if self.gripper_names:
for n in self.gripper_names:
features[PipelineFeatureType.ACTION].pop(f"{n}.pos", None)
ee_keys.append("gripper_pos")
for k in ee_keys:
features[PipelineFeatureType.ACTION][f"ee.{k}"] = PolicyFeature(
type=FeatureType.STATE, shape=(1,)
)
@@ -494,13 +528,14 @@ class ForwardKinematicsJointsToEEAction(RobotActionProcessorStep):
class ForwardKinematicsJointsToEE(ProcessorStep):
kinematics: RobotKinematics
motor_names: list[str]
gripper_names: list[str] = field(default_factory=lambda: ["gripper"])
def __post_init__(self):
self.joints_to_ee_action_processor = ForwardKinematicsJointsToEEAction(
kinematics=self.kinematics, motor_names=self.motor_names
kinematics=self.kinematics, motor_names=self.motor_names, gripper_names=self.gripper_names
)
self.joints_to_ee_observation_processor = ForwardKinematicsJointsToEEObservation(
kinematics=self.kinematics, motor_names=self.motor_names
kinematics=self.kinematics, motor_names=self.motor_names, gripper_names=self.gripper_names
)
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -524,13 +559,13 @@ class ForwardKinematicsJointsToEE(ProcessorStep):
@dataclass
class InverseKinematicsRLStep(ProcessorStep):
"""
Computes desired joint positions from a target end-effector pose using inverse kinematics (IK).
This is modified from the InverseKinematicsEEToJoints step to be used in the RL pipeline.
IK step for the RL pipeline. Same logic as InverseKinematicsEEToJoints but
operates on EnvTransition directly and stores the IK solution.
"""
kinematics: RobotKinematics
motor_names: list[str]
gripper_names: list[str] = field(default_factory=lambda: ["gripper"])
q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
initial_guess_current_joints: bool = True
@@ -538,7 +573,7 @@ class InverseKinematicsRLStep(ProcessorStep):
new_transition = dict(transition)
action = new_transition.get(TransitionKey.ACTION)
if action is None:
raise ValueError("Action is required for InverseKinematicsEEToJoints")
raise ValueError("Action is required for InverseKinematicsRLStep")
action = dict(action)
x = action.pop("ee.x")
@@ -547,45 +582,46 @@ class InverseKinematicsRLStep(ProcessorStep):
wx = action.pop("ee.wx")
wy = action.pop("ee.wy")
wz = action.pop("ee.wz")
gripper_pos = action.pop("ee.gripper_pos")
if None in (x, y, z, wx, wy, wz, gripper_pos):
raise ValueError(
"Missing required end-effector pose components: ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_pos must all be present in action"
)
ee_keys = [x, y, z, wx, wy, wz]
if self.gripper_names:
gripper_pos = action.pop("ee.gripper_pos")
ee_keys.append(gripper_pos)
if None in ee_keys:
raise ValueError("Missing required end-effector pose components in action")
observation = new_transition.get(TransitionKey.OBSERVATION).copy()
if observation is None:
raise ValueError("Joints observation is require for computing robot kinematics")
raise ValueError("Joints observation is required for computing robot kinematics")
q_raw = np.array(
[float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")],
[
float(v)
for k, v in observation.items()
if isinstance(k, str) and k.endswith(".pos") and k.removesuffix(".pos") in self.motor_names
],
dtype=float,
)
if q_raw is None:
raise ValueError("Joints observation is require for computing robot kinematics")
if self.initial_guess_current_joints: # Use current joints as initial guess
if self.initial_guess_current_joints:
self.q_curr = q_raw
else: # Use previous ik solution as initial guess
else:
if self.q_curr is None:
self.q_curr = q_raw
# Build desired 4x4 transform from pos + rotvec (twist)
t_des = np.eye(4, dtype=float)
t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
t_des[:3, 3] = [x, y, z]
# Compute inverse kinematics
q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
self.q_curr = q_target
# TODO: This is sentitive to order of motor_names = q_target mapping
for i, name in enumerate(self.motor_names):
if name != "gripper":
action[f"{name}.pos"] = float(q_target[i])
else:
action["gripper.pos"] = float(gripper_pos)
action[f"{name}.pos"] = float(q_target[i])
if self.gripper_names:
for gname in self.gripper_names:
action[f"{gname}.pos"] = float(gripper_pos)
new_transition[TransitionKey.ACTION] = action
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
@@ -596,16 +632,22 @@ class InverseKinematicsRLStep(ProcessorStep):
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
ee_feats = ["x", "y", "z", "wx", "wy", "wz"]
if self.gripper_names:
ee_feats.append("gripper_pos")
for feat in ee_feats:
features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None)
for name in self.motor_names:
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
for name in self.gripper_names:
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
return features
def reset(self):
"""Resets the initial guess for the IK solver."""
self.q_curr = None

View File

@@ -18,7 +18,7 @@
Edit LeRobot datasets using various transformation tools.
This script allows you to delete episodes, split datasets, merge datasets,
remove features, modify tasks, and convert image datasets to video format.
remove features, modify tasks, recompute stats, and convert image datasets to video format.
When new_repo_id is specified, creates a new dataset.
Path semantics (v2): --root and --new_root are exact dataset folders containing
@@ -148,6 +148,21 @@ Show dataset information without feature details:
--operation.type info \
--operation.show_features false
Recompute dataset statistics:
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type recompute_stats
Recompute stats for relative actions and push to hub:
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type recompute_stats \
--operation.relative_action true \
--operation.chunk_size 50 \
--operation.relative_exclude_joints "['gripper']" \
--operation.num_workers 4 \
--push_to_hub true
Using JSON config file:
lerobot-edit-dataset \
--config_path path/to/edit_config.json
@@ -168,6 +183,7 @@ from lerobot.datasets.dataset_tools import (
delete_episodes,
merge_datasets,
modify_tasks,
recompute_stats,
remove_feature,
split_dataset,
)
@@ -230,6 +246,20 @@ class ConvertImageToVideoConfig(OperationConfig):
max_frames_per_batch: int | None = None
@OperationConfig.register_subclass("recompute_stats")
@dataclass
class RecomputeStatsConfig(OperationConfig):
skip_image_video: bool = True
relative_action: bool = False
relative_exclude_joints: list[str] | None = None
chunk_size: int = 50
num_workers: int = 0
relative_state: bool = False
relative_exclude_state_joints: list[str] | None = None
state_obs_steps: int = 2
derive_state_from_action: bool = False
@OperationConfig.register_subclass("info")
@dataclass
class InfoConfig(OperationConfig):
@@ -525,6 +555,47 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
logging.info("Dataset saved locally (not pushed to hub)")
def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, RecomputeStatsConfig):
raise ValueError("Operation config must be RecomputeStatsConfig")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
logging.info(f"Recomputing stats for {cfg.repo_id}")
if cfg.operation.relative_action:
logging.info(
f"Relative action stats enabled (chunk_size={cfg.operation.chunk_size}, "
f"exclude_joints={cfg.operation.relative_exclude_joints})"
)
if cfg.operation.relative_state:
logging.info(
f"Relative state stats enabled (state_obs_steps={cfg.operation.state_obs_steps}, "
f"exclude_state_joints={cfg.operation.relative_exclude_state_joints})"
)
if cfg.operation.derive_state_from_action:
logging.info("Derive state from action enabled (implies relative_state=True, state_obs_steps=2)")
recompute_stats(
dataset,
skip_image_video=cfg.operation.skip_image_video,
relative_action=cfg.operation.relative_action,
relative_exclude_joints=cfg.operation.relative_exclude_joints,
chunk_size=cfg.operation.chunk_size,
num_workers=cfg.operation.num_workers,
relative_state=cfg.operation.relative_state,
relative_exclude_state_joints=cfg.operation.relative_exclude_state_joints,
state_obs_steps=cfg.operation.state_obs_steps,
derive_state_from_action=cfg.operation.derive_state_from_action,
)
logging.info(f"Stats written to {dataset.root}")
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {dataset.meta.repo_id}...")
dataset.push_to_hub()
def _get_dataset_size(repo_path):
import os
@@ -596,6 +667,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_modify_tasks(cfg)
elif operation_type == "convert_image_to_video":
handle_convert_image_to_video(cfg)
elif operation_type == "recompute_stats":
handle_recompute_stats(cfg)
elif operation_type == "info":
handle_info(cfg)
else:

View File

@@ -65,6 +65,7 @@ def get_sys_info() -> dict[str, str]:
"Platform": platform.platform(),
"Python version": platform.python_version(),
"Huggingface Hub version": get_package_version("huggingface_hub"),
"Transformers version": get_package_version("transformers"),
"Datasets version": get_package_version("datasets"),
"Numpy version": get_package_version("numpy"),
"FFmpeg version": get_ffmpeg_version(),

View File

@@ -468,7 +468,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
try:
if cfg.resume:
dataset = LeRobotDataset(
num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
dataset = LeRobotDataset.resume(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
@@ -476,13 +477,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
if num_cameras > 0
else 0,
)
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
else:
# Create empty dataset or load existing saved episodes

View File

@@ -104,15 +104,13 @@ def replay(cfg: ReplayConfig):
robot = make_robot_from_config(cfg.robot)
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode)
actions = episode_frames.select_columns(ACTION)
actions = dataset.select_columns(ACTION)
robot.connect()
try:
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(len(episode_frames)):
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action_array = actions[idx][ACTION]

View File

@@ -252,10 +252,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Wait for all processes to finish policy creation before continuing
accelerator.wait_for_everyone()
processor_pretrained_path = cfg.policy.pretrained_path
if (
getattr(cfg.policy, "use_relative_actions", False)
and processor_pretrained_path is not None
and not cfg.resume
):
logging.warning(
"use_relative_actions=true with pretrained processors can skip relative transforms if "
"the checkpoint processors do not define them. Building processors from current policy config."
)
processor_pretrained_path = None
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
postprocessor_kwargs = {}
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
@@ -263,7 +275,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if cfg.policy.type == "sarm":
processor_kwargs["dataset_meta"] = dataset.meta
if cfg.policy.pretrained_path is not None:
if processor_pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
"device_processor": {"device": device.type},
"normalizer_processor": {
@@ -285,7 +297,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
pretrained_path=processor_pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)

View File

@@ -15,7 +15,7 @@
This script:
1. Loads action chunks from LeRobotDataset (with episode sampling)
2. Optionally applies delta transforms (relative vs absolute actions)
2. Optionally applies relative transforms (relative vs absolute actions)
3. Extracts specified action dimensions for encoding
4. Applies normalization (MEAN_STD, MIN_MAX, QUANTILES, or other modes)
5. Trains FAST tokenizer (BPE on DCT coefficients) on the action chunks
@@ -32,8 +32,8 @@ lerobot-train-tokenizer \
--max_episodes=100 \
--sample_fraction=0.1 \
--encoded_dims="0:6" \
--delta_dims="0,1,2,3,4,5" \
--use_delta_transform=true \
--relative_dims="0,1,2,3,4,5" \
--use_relative_transform=true \
--state_key="observation.state" \
--normalization_mode="QUANTILES" \
--vocab_size=1024 \
@@ -82,10 +82,10 @@ class TokenizerTrainingConfig:
sample_fraction: float = 0.1
# Comma-separated dimension ranges to encode (e.g., "0:6,7:23")
encoded_dims: str = "0:6,7:23"
# Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
delta_dims: str | None = None
# Whether to apply delta transform (relative actions vs absolute actions)
use_delta_transform: bool = False
# Comma-separated dimension indices for relative transform (e.g., "0,1,2,3,4,5")
relative_dims: str | None = None
# Whether to apply relative transform (relative actions vs absolute actions)
use_relative_transform: bool = False
# Dataset key for state observations (default: "observation.state")
state_key: str = OBS_STATE
# Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
@@ -104,25 +104,27 @@ class TokenizerTrainingConfig:
hub_private: bool = False
def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray:
"""Apply delta transform to specified dimensions.
def apply_relative_transform(
state: np.ndarray, actions: np.ndarray, relative_dims: list[int] | None
) -> np.ndarray:
"""Apply relative transform to specified dimensions.
Args:
state: Current state [D]
actions: Future actions [D]
delta_dims: List of dimension indices to apply delta transform to
relative_dims: List of dimension indices to apply relative transform to
Returns:
Transformed actions [D]
"""
if delta_dims is None or len(delta_dims) == 0:
if relative_dims is None or len(relative_dims) == 0:
return actions
delta_actions = actions.copy()
for dim in delta_dims:
delta_actions[dim] = actions[dim] - state[dim]
relative_actions = actions.copy()
for dim in relative_dims:
relative_actions[dim] = actions[dim] - state[dim]
return delta_actions
return relative_actions
def apply_normalization(
@@ -185,7 +187,7 @@ def apply_normalization(
def process_episode(args):
"""Process single episode and return action chunks."""
dataset, ep_idx, action_horizon, delta_dims, sample_fraction, state_key, use_delta_transform = args
dataset, ep_idx, action_horizon, relative_dims, sample_fraction, state_key, use_relative_transform = args
try:
# get episode info
@@ -204,15 +206,15 @@ def process_episode(args):
for abs_idx in range(from_idx, to_idx):
# map absolute index to relative index if needed
if dataset._absolute_to_relative_idx is not None:
if abs_idx not in dataset._absolute_to_relative_idx:
if dataset.reader._absolute_to_relative_idx is not None:
if abs_idx not in dataset.reader._absolute_to_relative_idx:
# this episode's frames aren't in the filtered dataset
return None
rel_idx = dataset._absolute_to_relative_idx[abs_idx]
rel_idx = dataset.reader._absolute_to_relative_idx[abs_idx]
else:
rel_idx = abs_idx
frame = dataset.hf_dataset[rel_idx]
frame = dataset.get_raw_item(rel_idx)
# get state (could be from observation.state or other state key)
if state_key in frame:
@@ -222,7 +224,7 @@ def process_episode(args):
else np.array(frame[state_key])
)
else:
# if no state key, use zeros (no delta transform)
# if no state key, use zeros (no relative transform)
state = np.zeros_like(
frame[ACTION].numpy() if torch.is_tensor(frame[ACTION]) else np.array(frame[ACTION])
)
@@ -243,18 +245,18 @@ def process_episode(args):
current_state = states[i] # First state in chunk
future_absolute_actions = actions[i : i + action_horizon]
if use_delta_transform:
if use_relative_transform:
# relative actions
delta_chunk = np.zeros_like(future_absolute_actions)
relative_chunk = np.zeros_like(future_absolute_actions)
for t in range(action_horizon):
delta_chunk[t] = apply_delta_transform(
relative_chunk[t] = apply_relative_transform(
current_state,
future_absolute_actions[t],
delta_dims,
relative_dims,
)
action_chunks.append(delta_chunk)
action_chunks.append(relative_chunk)
else:
# absolute actions (no delta)
# absolute actions (no relative transform)
action_chunks.append(future_absolute_actions)
if len(action_chunks) == 0:
@@ -407,17 +409,20 @@ def train_tokenizer(cfg: TokenizerTrainingConfig):
total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
print(f"Encoding {total_encoded_dims} dimensions: {cfg.encoded_dims}")
# parse delta dimensions
delta_dim_list = None
if cfg.delta_dims is not None and cfg.delta_dims.strip():
delta_dim_list = [int(d.strip()) for d in cfg.delta_dims.split(",")]
print(f"Delta dimensions: {delta_dim_list}")
# parse relative dimensions
relative_dim_list = None
if cfg.relative_dims is not None and cfg.relative_dims.strip():
relative_dim_list = [int(d.strip()) for d in cfg.relative_dims.split(",")]
print(f"Relative dimensions: {relative_dim_list}")
else:
print("No delta dimensions specified")
print("No relative dimensions specified")
print(f"Use delta transform: {cfg.use_delta_transform}")
if cfg.use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0):
print("Warning: use_delta_transform=True but no delta_dims specified. No delta will be applied.")
print(f"Use relative transform: {cfg.use_relative_transform}")
if cfg.use_relative_transform and (relative_dim_list is None or len(relative_dim_list) == 0):
print(
"Warning: use_relative_transform=True but no relative_dims specified. "
"No relative transform will be applied."
)
print(f"Action horizon: {cfg.action_horizon}")
print(f"State key: {cfg.state_key}")
@@ -440,10 +445,10 @@ def train_tokenizer(cfg: TokenizerTrainingConfig):
dataset,
ep_idx,
cfg.action_horizon,
delta_dim_list,
relative_dim_list,
cfg.sample_fraction,
cfg.state_key,
cfg.use_delta_transform,
cfg.use_relative_transform,
)
)
if chunks is not None:
@@ -544,9 +549,9 @@ def train_tokenizer(cfg: TokenizerTrainingConfig):
"encoded_dims": cfg.encoded_dims,
"encoded_dim_ranges": encoded_dim_ranges,
"total_encoded_dims": total_encoded_dims,
"delta_dims": cfg.delta_dims,
"delta_dim_list": delta_dim_list,
"use_delta_transform": cfg.use_delta_transform,
"relative_dims": cfg.relative_dims,
"relative_dim_list": relative_dim_list,
"use_relative_transform": cfg.use_relative_transform,
"state_key": cfg.state_key,
"normalization_mode": norm_mode.value,
"action_horizon": cfg.action_horizon,

View File

@@ -65,6 +65,10 @@ if "LEROBOT_HOME" in os.environ:
# cache dir
default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
# LeRobot's own revision-safe Hub cache (NOT the system-wide ~/.cache/huggingface/hub/).
# Used as the ``cache_dir`` argument to ``snapshot_download`` so that different
# dataset revisions are stored in isolated snapshot directories.
HF_LEROBOT_HUB_CACHE = HF_LEROBOT_HOME / "hub"
# calibration dir
default_calibration_path = HF_LEROBOT_HOME / "calibration"

View File

@@ -95,6 +95,8 @@ def init_logging(
file_handler.setLevel(file_level.upper())
logger.addHandler(file_handler)
logging.getLogger("httpx").setLevel(logging.WARNING)
def format_big_number(num, precision=0):
suffixes = ["", "K", "M", "B", "T", "Q"]

View File

@@ -80,7 +80,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
# HACK: We reload a batch with no delta_indices as `select_action` won't expect a timestamps dimension
# We simulate having an environment using a dataset by setting delta_indices to None and dropping tensors
# indicating padding (those ending with "_is_pad")
dataset.delta_indices = None
dataset.reader.delta_indices = None
batch = next(iter(dataloader))
obs = {}
for k in batch:

View File

@@ -0,0 +1,385 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contract tests for LeRobotDatasetMetadata."""
import json
import numpy as np
import pytest
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.utils import INFO_PATH
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_ROBOT_TYPE
# ── helpers ──────────────────────────────────────────────────────────
SIMPLE_FEATURES = {
"state": {"dtype": "float32", "shape": (6,), "names": None},
"action": {"dtype": "float32", "shape": (6,), "names": None},
}
VIDEO_FEATURES = {
**SIMPLE_FEATURES,
"observation.images.laptop": {
"dtype": "video",
"shape": (64, 96, 3),
"names": ["height", "width", "channels"],
"info": None,
},
}
IMAGE_FEATURES = {
**SIMPLE_FEATURES,
"observation.images.laptop": {
"dtype": "image",
"shape": (64, 96, 3),
"names": ["height", "width", "channels"],
"info": None,
},
}
def _make_dummy_stats(features: dict) -> dict:
"""Create minimal episode stats matching the given features."""
stats = {}
for key, ft in features.items():
if ft["dtype"] in ("image", "video"):
stats[key] = {
"max": np.ones((3, 1, 1), dtype=np.float32),
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32),
"min": np.zeros((3, 1, 1), dtype=np.float32),
"std": np.full((3, 1, 1), 0.25, dtype=np.float32),
"count": np.array([5]),
}
elif ft["dtype"] in ("float32", "float64", "int64"):
stats[key] = {
"max": np.ones(ft["shape"], dtype=np.float32),
"mean": np.full(ft["shape"], 0.5, dtype=np.float32),
"min": np.zeros(ft["shape"], dtype=np.float32),
"std": np.full(ft["shape"], 0.25, dtype=np.float32),
"count": np.array([5]),
}
return stats
# ── Construction contracts ───────────────────────────────────────────
def test_create_produces_valid_info_on_disk(tmp_path):
"""create() writes info.json and the returned object reflects the provided settings."""
root = tmp_path / "new_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/meta",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
robot_type=DUMMY_ROBOT_TYPE,
root=root,
use_videos=False,
)
# info.json was written to disk
assert (root / INFO_PATH).exists()
with open(root / INFO_PATH) as f:
info_on_disk = json.load(f)
assert meta.fps == DEFAULT_FPS
assert meta.robot_type == DUMMY_ROBOT_TYPE
assert "state" in meta.features
assert "action" in meta.features
assert info_on_disk["fps"] == DEFAULT_FPS
def test_create_starts_with_zero_counts(tmp_path):
"""A freshly created metadata has zero episode/frame/task counts."""
root = tmp_path / "empty_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/empty", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
assert meta.total_episodes == 0
assert meta.total_frames == 0
assert meta.total_tasks == 0
assert meta.tasks is None
assert meta.episodes is None
assert meta.stats is None
def test_create_with_videos_sets_video_path(tmp_path):
"""When features include video-dtype keys, create() produces a non-None video_path."""
root = tmp_path / "video_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/video", fps=DEFAULT_FPS, features=VIDEO_FEATURES, root=root, use_videos=True
)
assert meta.video_path is not None
assert len(meta.video_keys) == 1
assert "observation.images.laptop" in meta.video_keys
def test_create_without_videos_has_no_video_path(tmp_path):
"""When use_videos=False and no video features, video_path is None."""
root = tmp_path / "no_video"
meta = LeRobotDatasetMetadata.create(
repo_id="test/novid", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
assert meta.video_path is None
assert meta.video_keys == []
def test_create_raises_on_existing_directory(tmp_path):
"""create() raises if root directory already exists."""
root = tmp_path / "existing"
root.mkdir()
with pytest.raises(FileExistsError):
LeRobotDatasetMetadata.create(
repo_id="test/exists", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
def test_init_loads_existing_metadata(tmp_path, lerobot_dataset_metadata_factory, info_factory):
"""When metadata files exist on disk, __init__ loads them correctly."""
root = tmp_path / "load_test"
info = info_factory(total_episodes=3, total_frames=150, total_tasks=1, use_videos=False)
meta = lerobot_dataset_metadata_factory(root=root, info=info)
assert meta.total_episodes == 3
assert meta.total_frames == 150
assert meta.fps == info["fps"]
# ── Property accessors ───────────────────────────────────────────────
def test_property_accessors_reflect_info(tmp_path):
"""Properties return values consistent with the info dict."""
root = tmp_path / "props_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/props",
fps=DEFAULT_FPS,
features=IMAGE_FEATURES,
robot_type=DUMMY_ROBOT_TYPE,
root=root,
use_videos=False,
)
assert meta.fps == DEFAULT_FPS
assert meta.robot_type == DUMMY_ROBOT_TYPE
# shapes should be tuples
for _key, shape in meta.shapes.items():
assert isinstance(shape, tuple)
# image_keys should contain the image feature
assert "observation.images.laptop" in meta.image_keys
# camera_keys is a superset of image_keys and video_keys
assert set(meta.image_keys + meta.video_keys) == set(meta.camera_keys)
def test_data_path_is_formattable(tmp_path):
"""data_path contains format placeholders that can be .format()-ed."""
root = tmp_path / "fmt_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/fmt", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
formatted = meta.data_path.format(chunk_index=0, file_index=0)
assert "chunk" in formatted.lower() or "0" in formatted
# ── Task management ──────────────────────────────────────────────────
def test_save_episode_tasks_creates_tasks_dataframe(tmp_path):
"""On a fresh metadata, save_episode_tasks() creates the tasks DataFrame."""
root = tmp_path / "task_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/task", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
assert meta.tasks is None
meta.save_episode_tasks(["Pick up the cube"])
assert meta.tasks is not None
assert len(meta.tasks) == 1
assert "Pick up the cube" in meta.tasks.index
def test_save_episode_tasks_is_additive(tmp_path):
"""New tasks are added; existing tasks keep their original index."""
root = tmp_path / "additive_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/add", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
meta.save_episode_tasks(["Task A"])
idx_a = meta.get_task_index("Task A")
meta.save_episode_tasks(["Task A", "Task B"])
assert meta.get_task_index("Task A") == idx_a # unchanged
assert meta.get_task_index("Task B") is not None
assert len(meta.tasks) == 2
def test_get_task_index_returns_none_for_unknown(tmp_path):
"""get_task_index() returns None for an unknown task."""
root = tmp_path / "unknown_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/unknown", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
meta.save_episode_tasks(["Known task"])
assert meta.get_task_index("Known task") == 0
assert meta.get_task_index("Unknown task") is None
def test_save_episode_tasks_rejects_duplicates(tmp_path):
"""save_episode_tasks() raises ValueError on duplicate task strings."""
root = tmp_path / "dup_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/dup", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
with pytest.raises(ValueError):
meta.save_episode_tasks(["Same task", "Same task"])
# ── Episode saving ───────────────────────────────────────────────────
def test_save_episode_increments_counters(tmp_path):
"""After save_episode(), total_episodes and total_frames increase."""
root = tmp_path / "ep_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/ep", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
meta.save_episode_tasks(["Task 1"])
stats = _make_dummy_stats(meta.features)
meta.save_episode(
episode_index=0,
episode_length=10,
episode_tasks=["Task 1"],
episode_stats=stats,
episode_metadata={},
)
assert meta.total_episodes == 1
assert meta.total_frames == 10
def test_save_episode_updates_stats(tmp_path):
"""After save_episode(), .stats is non-None and has feature keys."""
root = tmp_path / "stats_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/stats", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
meta.save_episode_tasks(["Task 1"])
stats = _make_dummy_stats(meta.features)
meta.save_episode(
episode_index=0,
episode_length=5,
episode_tasks=["Task 1"],
episode_stats=stats,
episode_metadata={},
)
assert meta.stats is not None
# Stats should contain at least the user-defined feature keys
for key in SIMPLE_FEATURES:
assert key in meta.stats
# ── Chunk settings ───────────────────────────────────────────────────
def test_update_chunk_settings_persists(tmp_path):
"""update_chunk_settings() changes values and writes info.json."""
root = tmp_path / "chunk_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/chunk", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
original = meta.get_chunk_settings()
meta.update_chunk_settings(chunks_size=500)
assert meta.chunks_size == 500
assert meta.chunks_size != original["chunks_size"] or original["chunks_size"] == 500
# Verify persisted
with open(root / INFO_PATH) as f:
info_on_disk = json.load(f)
assert info_on_disk["chunks_size"] == 500
def test_update_chunk_settings_rejects_non_positive(tmp_path):
"""update_chunk_settings() raises ValueError for <= 0 values."""
root = tmp_path / "bad_chunk"
meta = LeRobotDatasetMetadata.create(
repo_id="test/bad", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
with pytest.raises(ValueError):
meta.update_chunk_settings(chunks_size=0)
with pytest.raises(ValueError):
meta.update_chunk_settings(data_files_size_in_mb=-1)
# ── Finalization ─────────────────────────────────────────────────────
def test_finalize_is_idempotent(tmp_path):
"""Calling finalize() multiple times does not raise."""
root = tmp_path / "fin_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/fin", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
meta.finalize()
meta.finalize() # second call should not raise
def test_finalize_flushes_buffered_metadata(tmp_path):
"""Episodes saved before finalize() are written to parquet."""
root = tmp_path / "flush_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/flush",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
metadata_buffer_size=100, # large buffer so nothing auto-flushes
)
meta.save_episode_tasks(["Task 1"])
stats = _make_dummy_stats(meta.features)
# Save a few episodes (won't auto-flush since buffer_size=100)
for i in range(3):
meta.save_episode(
episode_index=i,
episode_length=5,
episode_tasks=["Task 1"],
episode_stats=stats,
episode_metadata={},
)
# Before finalize, the parquet might not exist yet
meta.finalize()
# After finalize, episodes parquet should exist
episodes_dir = root / "meta" / "episodes"
assert episodes_dir.exists()
parquet_files = list(episodes_dir.rglob("*.parquet"))
assert len(parquet_files) > 0

View File

@@ -0,0 +1,168 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contract tests for DatasetReader."""
from lerobot.datasets.dataset_reader import DatasetReader
from lerobot.datasets.video_utils import get_safe_default_codec
# ── Loading ──────────────────────────────────────────────────────────
def test_try_load_returns_true_when_data_exists(tmp_path, lerobot_dataset_factory):
"""Given a fully written dataset, try_load() returns True."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False
)
reader = DatasetReader(
meta=dataset.meta,
root=dataset.root,
episodes=None,
tolerance_s=1e-4,
video_backend=get_safe_default_codec(),
delta_timestamps=None,
image_transforms=None,
)
assert reader.try_load() is True
assert reader.hf_dataset is not None
def test_try_load_returns_false_when_no_data(tmp_path):
"""When only metadata exists (no data/ parquets), try_load() returns False."""
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
root = tmp_path / "meta_only"
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
meta = LeRobotDatasetMetadata.create(
repo_id="test/meta_only", fps=30, features=features, root=root, use_videos=False
)
reader = DatasetReader(
meta=meta,
root=meta.root,
episodes=None,
tolerance_s=1e-4,
video_backend=get_safe_default_codec(),
delta_timestamps=None,
image_transforms=None,
)
assert reader.try_load() is False
assert reader.hf_dataset is None
# ── Counts ───────────────────────────────────────────────────────────
def test_num_frames_without_filter(tmp_path, lerobot_dataset_factory):
"""With episodes=None, num_frames equals total_frames."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=3, total_frames=60, use_videos=False
)
assert dataset.reader.num_frames == dataset.meta.total_frames
def test_num_episodes_without_filter(tmp_path, lerobot_dataset_factory):
"""With episodes=None, num_episodes equals total_episodes."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=3, total_frames=60, use_videos=False
)
assert dataset.reader.num_episodes == dataset.meta.total_episodes
def test_num_frames_with_episode_filter(tmp_path, lerobot_dataset_factory):
"""When filtering to a subset, only those episodes' frames are counted."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=5, total_frames=100, episodes=[0, 2], use_videos=False
)
# Filtered frames should be less than total
assert dataset.reader.num_frames <= dataset.meta.total_frames
assert dataset.reader.num_episodes == 2
# ── get_item ─────────────────────────────────────────────────────────
def test_get_item_returns_expected_keys(tmp_path, lerobot_dataset_factory):
"""get_item(0) returns a dict with expected keys."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False
)
item = dataset.reader.get_item(0)
# Standard keys that must always be present
for key in ["index", "episode_index", "frame_index", "timestamp", "task_index", "task"]:
assert key in item, f"Missing key: {key}"
def test_get_item_values_are_correct(tmp_path, lerobot_dataset_factory):
"""get_item() returns correct index and episode_index."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False
)
item_0 = dataset.reader.get_item(0)
assert item_0["index"].item() == 0
assert item_0["episode_index"].item() == 0
# ── Transforms ───────────────────────────────────────────────────────
def test_image_transforms_are_applied(tmp_path, lerobot_dataset_factory):
"""When image_transforms is provided, get_item() applies it to camera keys."""
transform_called = {"count": 0}
def sentinel_transform(img):
transform_called["count"] += 1
return img
dataset = lerobot_dataset_factory(
root=tmp_path / "ds",
total_episodes=1,
total_frames=5,
use_videos=False,
image_transforms=sentinel_transform,
)
item = dataset[0] # noqa: F841
# Should have been called once per camera key per frame
num_cameras = len(dataset.meta.camera_keys)
if num_cameras > 0:
assert transform_called["count"] >= 1
# ── File paths ───────────────────────────────────────────────────────
def test_get_episodes_file_paths_returns_data_paths(tmp_path, lerobot_dataset_factory):
"""get_episodes_file_paths() returns paths including data/ paths."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False
)
paths = dataset.reader.get_episodes_file_paths()
assert len(paths) > 0
assert any("data/" in str(p) for p in paths)
def test_get_episodes_file_paths_includes_video_paths(tmp_path, lerobot_dataset_factory):
"""When dataset has video keys, file paths include video/ paths."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=True
)
if len(dataset.meta.video_keys) > 0:
paths = dataset.reader.get_episodes_file_paths()
assert any("video" in str(p).lower() for p in paths)

Some files were not shown because too many files have changed in this diff Show More