Compare commits

..

31 Commits

Author SHA1 Message Date
Pepijn
46e9e22b05 feat(eval): thread-safe policy copies for max_parallel_tasks > 1
eval_policy_all already supports running multiple task groups concurrently via
ThreadPoolExecutor, but policy.reset() was not thread-safe: all threads shared
the same policy object and its mutable state (action queues, temporal buffers).

Fix: each thread receives a shallow copy of the policy. copy.copy() creates a
new Python object whose _parameters dict is a shared reference — same tensor
storage, zero extra VRAM — while reset() rebinds per-episode state to fresh
objects per thread.

Caveat: ACT with temporal_ensemble_coeff is not safe with this approach (its
reset() mutates a shared sub-object). Keep max_parallel_tasks=1 for that config.

For MetaWorld (50 tasks, no temporal ensembling), max_parallel_tasks=4 raises
GPU utilization from ~20% to ~60-80% with no additional VRAM cost.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-03 17:11:36 +02:00
Pepijn
b43f9ab048 feat(envs): lazy env init + AsyncVectorEnv as default for n_envs > 1
LiberoEnv and MetaworldEnv previously allocated GPU resources (EGL context,
OpenGL framebuffer) in __init__, before AsyncVectorEnv's fork(). Worker
processes inherited stale GPU handles, causing EGL_BAD_CONTEXT crashes on
first render.

Fix: defer OffScreenRenderEnv / MT1 construction to _ensure_env(), called on
first reset() or step() inside the worker subprocess. Each worker creates its
own clean context after fork().

Also fixes lerobot_eval.py:170 (add_envs_task TODO): replace with
env.call("task") which works with both SyncVectorEnv and AsyncVectorEnv.

AsyncVectorEnv is now the default for n_envs > 1; auto-downgraded to
SyncVectorEnv when n_envs=1 (no benefit, less overhead).

Expected speedup: ~15-20x for LIBERO Spatial with batch_size=50.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-03 17:10:10 +02:00
Pepijn
0045f88355 merge: resolve conflicts from main into refactor/benchmark-dispatch
Keep refactored dispatch pattern (no factory.py edits for new benchmarks).
Incorporate main's "Verifying your integration" section and class naming fix.

Made-with: Cursor
2026-04-03 14:49:36 +02:00
Pepijn
4dbbcca496 docs(benchmarks): add benchmark integration guide and standardize benchmark docs (#3270)
* docs(benchmarks): add benchmark integration guide and standardize benchmark docs

Add a comprehensive guide for adding new benchmarks to LeRobot, and
refactor the existing LIBERO and Meta-World docs to follow the new
standardized template.

Made-with: Cursor

* docs(benchmarks): clean up adding-benchmarks guide for clarity

Rewrite for simpler language, better structure, and easier navigation.
Move quick-reference table to the top, fold eval explanation into
architecture section, condense the doc template to a bulleted outline.

Made-with: Cursor

* fix link

* fix task count

* Update docs/source/adding_benchmarks.mdx

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Update docs/source/metaworld.mdx

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Update docs/source/adding_benchmarks.mdx

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Update docs/source/adding_benchmarks.mdx

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Update docs/source/adding_benchmarks.mdx

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* docs(benchmarks): add verification checklist to adding-benchmarks guide

Made-with: Cursor

---------

Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
2026-04-03 14:44:53 +02:00
Pepijn
89ce91f69f Merge branch 'docs/adding-benchmarks-guide' into refactor/benchmark-dispatch 2026-04-03 13:56:49 +02:00
Pepijn
90e614f6b9 fix task count 2026-04-03 13:48:37 +02:00
Pepijn
ff4f860e5d fix link 2026-04-03 13:47:17 +02:00
Pepijn
6f2823bfc4 merge: resolve conflicts with docs/adding-benchmarks-guide
Incorporate cleaner writing from the docs branch while reflecting the
refactored dispatch pattern (no factory.py edits needed for new benchmarks).

Made-with: Cursor
2026-04-03 13:45:12 +02:00
Pepijn
77415559b8 docs(benchmarks): clean up adding-benchmarks guide for clarity
Rewrite for simpler language, better structure, and easier navigation.
Move quick-reference table to the top, fold eval explanation into
architecture section, condense the doc template to a bulleted outline.

Made-with: Cursor
2026-04-03 13:36:16 +02:00
Pepijn
24d9b74d81 refactor(envs): move dispatch logic from factory into EnvConfig subclasses
Replace hardcoded if/elif chains in factory.py with create_envs() and
get_env_processors() methods on EnvConfig. New benchmarks now only need
to register a config subclass — no factory.py edits required.

Net -23 lines: factory.py shrinks from ~200 to ~70 lines of logic.

Made-with: Cursor
2026-04-03 13:23:44 +02:00
Pepijn
508358749a docs(benchmarks): add benchmark integration guide and standardize benchmark docs
Add a comprehensive guide for adding new benchmarks to LeRobot, and
refactor the existing LIBERO and Meta-World docs to follow the new
standardized template.

Made-with: Cursor
2026-04-02 20:43:31 +02:00
Pepijn
818892a38b feat(dagger): Add HIL/Dagger/HG-Dagger/RaC style data collection (#2833)
* feat: HIL data collection, RTC interpolator, and action queue improvements

- Add Human-in-the-Loop (HIL) data collection examples (sync + RTC)
- Add HIL data collection documentation
- Add ActionInterpolator for smoother policy control at higher rates
- Integrate interpolator into lerobot-record and eval_with_real_robot
- Add action queue clear() and get_processed_left_over() methods
- Add rtc/__init__.py for cleaner imports

* docs: expand Related Work section with paper summaries

* fix: only record dataset frames at original fps, not at interpolated rate

The interpolator speeds up robot control (e.g. 2x) but dataset frames
should still be recorded at the original fps. Interpolated-only
iterations now only send actions to the robot without writing to the
dataset.

* refactor: merge HIL sync and RTC scripts into single file with --rtc.enabled toggle

Combines hil_data_collection.py and hil_data_collection_rtc.py into one
script. RTC is toggled via --rtc.enabled=true (defaults to off for sync
inference). Deletes the separate hil_data_collection_rtc.py and updates
docs to reflect the single-script usage.

* test: add ActionInterpolator test suite (29 tests)

Covers constructor validation, passthrough (multiplier=1), 2x and 3x
interpolation with exact value checks, reset/episode boundaries,
control interval calculation, multi-dim actions, and simulated
control loop integration.

* test: add ActionQueue + ActionInterpolator integration tests

Verifies the interpolator doesn't interfere with RTC's leftover chunk
tracking: queue consumption rate matches base fps regardless of
multiplier, get_left_over/get_processed_left_over only change on
queue.get(), merge preserves smooth interpolation across chunks,
and interpolator reset is independent of queue state.

* feat: register SO follower/leader configs in HIL script

Adds SOFollowerRobotConfig and SOLeaderTeleopConfig imports so
SO100/SO101 robots can be used via --robot.type=so_follower
and --teleop.type=so_leader. Updates docs accordingly.

Made-with: Cursor

* docs: remove em dashes from HIL documentation

Made-with: Cursor

* refactor: rename examples/rac to examples/hil

Updates directory name and all references in docs and script docstrings.

Made-with: Cursor

* fix: encorperate pr feedback comments

* refactor(tests): enhance ActionInterpolator test structure and add detailed docstrings

* feedback pr and test fix

* fix(test): pass correct real_delay in interpolator delay test

The test was passing real_delay=0 and relying on _check_delays to
silently override it with the index-based diff. Now passes real_delay=3
to match the 3 actions consumed during the simulated inference period.


* fix pr feedback

* ordering

* update hil script

* fix

* default name

* fix(bi_openarm): use kw_only=True to fix dataclass field ordering

BiOpenArmFollowerConfig overrides `id` with a default, making it
positional in the child — non-default `left_arm_config` then follows a
default field, which Python dataclasses forbid. Adding kw_only=True
(matching the parent RobotConfig) removes positional constraints.

Made-with: Cursor

* style: format long line in hil_data_collection.py

Made-with: Cursor

* pr feedback

---------

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
2026-04-02 19:53:59 +02:00
Pepijn
66fef25ded docs(toctree): add Benchmarks section for LIBERO and Meta-World (#3268)
* docs(toctree): add Benchmarks section for LIBERO and Meta-World

Move LIBERO and Meta-World pages out of the Simulation section into a
dedicated Benchmarks section so benchmark-specific docs are easier to
find and the Simulation section stays focused on environment hubs.

Made-with: Cursor

* docs(toctree): move IsaacLab Arena into Benchmarks section

Include NVIDIA IsaacLab Arena Environments alongside LIBERO and
Meta-World in the Benchmarks section.

Made-with: Cursor
2026-04-02 19:52:39 +02:00
Pepijn
2cf08b7a4b Add create reward visualization (#3155)
* Add create reward visualization and multimodal analysis tool

* add example for creating progress video for sarm

* nit

* precommit

* refactor: address review comments on create_progress_videos.py

- Add shebang and Apache 2.0 license header
- Replace hardcoded absolute OUTPUT_DIR with relative default (./progress_videos)
- Add argparse CLI (--repo-id, --episode, --camera-key, --output-dir, --gif)
- Wrap entrypoint in def main()
- Replace all print() with logging
- Use logging.error/warning instead of traceback.print_exc
- Release VideoCapture via try/finally; consolidate triple-open into single seek
- Eliminate intermediate clip file: seek directly via CAP_PROP_POS_MSEC
- Make MP4 the default output, GIF opt-in via --gif flag
- Add return types to all functions
- Add Args/Returns docstrings
- Use descriptive variable names throughout

Made-with: Cursor

* refactor: move create_progress_videos.py to examples/dataset/ for consistency

Made-with: Cursor

* refactor: address PR review comments on create_progress_videos.py

- Replace Unicode ellipsis and multiplication sign with ASCII equivalents
- Fix step numbering from 1-5 to 1-4 (only 4 actual steps)
- Move frame_width reading into convert_mp4_to_gif
- Remove unused text_height variable

Made-with: Cursor
2026-04-02 16:58:07 +02:00
Pepijn
15934d8d08 feat(policies): add relative action support for pi0, pi0.5, and pi0_fast (#2970)
* Add option for pi family models to train with relative actions (relative to state)

* formatting

* add recomputation of stats and option to compute delta stats

* normalzie after delta conversion

* only recompute state for stats

* calulate chunk based stats

* sample 100k

* load from parquet

* sample 1m

* stats per chunck

* fix

* use quantiles

* stats for entire dataset

* fix

* max 1m frames

* compute before dist

* fix multi gpu processor bug

* Fix RTC with delta actions and OpenArms motor_type wiring

* feat: align pi0_fast delta actions with pi0/pi05 and add RTC integration tests

- Add delta_exclude_joints and action_feature_names to PI0FastConfig
- Move to_absolute_actions from modeling to processor pipeline for pi0_fast
- Add delta action detection and logging to eval_with_real_robot.py
- Add delta actions documentation to pi0 and pi05 READMEs
- Fix ruff lint issues in test_delta_actions.py
- Add test_rtc_delta_actions.py (24 tests) covering:
  - ActionQueue with delta vs absolute actions
  - RTC denoise step with delta leftovers
  - Full pipeline roundtrip (delta → RTC → absolute)
  - State rebasing approximation bounds
  - Non-delta policy compatibility
  - Multi-chunk consistency

* chore: clean up test comments, add OpenPI attribution, remove debug logging

- Replace decorative comment separators in test files with plain section headers
- Add attribution comments for 1e-6 epsilon in normalize_processor.py (from OpenPI)
- Remove debug logging blocks from lerobot_train.py

* refactor: extract compute_delta_action_stats into compute_stats.py

Move the ~70-line inline delta action stats block from lerobot_train.py
into a dedicated function in compute_stats.py, where all other stats
computation already lives. The training script now calls it in 6 lines.

* refactor: remove unused get_processed_left_over from ActionQueue

This method was never called outside of tests. Leftover actions for RTC
guidance are always retrieved via get_left_over() (delta/original space).

* revert: remove logging-only changes from eval_with_real_robot.py

The delta actions detection helper and log message added no functional
value — the script already handles delta policies correctly via the
processor pipeline.

* refactor: use ACTION/OBS_STATE constants instead of hardcoded strings

Replace hardcoded "action" and "observation.state" with ACTION and
OBS_STATE from utils.constants in compute_stats.py, dataset_tools.py,
and lerobot_train.py.

* style: remove stray blank lines in training loop

* refactor: move delta action stats to preprocessing step, remove on-the-fly computation

- Remove on-the-fly compute_delta_action_stats from lerobot_train.py
- Rewrite recompute_stats to delegate action stats to compute_delta_action_stats
  (chunk-based sampling matching what the model sees during training)
- Add chunk_size parameter to recompute_stats for delta action computation
- Add delta actions documentation to pi0.mdx and pi05.mdx

* feat: add recompute_stats CLI operation to lerobot-edit-dataset

* fix(tests): relax quantile normalization test tolerance for 1e-6 epsilon

* chore: remove agents_memory/pr_details.md from repo

* refactor: rename delta actions to relative actions throughout

What OpenPI calls "DeltaActions" is actually UMI's "relative trajectory"
representation: each action in the chunk is an offset from the current
state, not from the previous action. This avoids error accumulation.

Renamed across all source, tests, docs, and CLI:
- DeltaActionsProcessorStep → RelativeActionsProcessorStep
- to_delta_actions → to_relative_actions
- use_delta_actions → use_relative_actions
- delta_exclude_joints → relative_exclude_joints
- compute_delta_action_stats → compute_relative_action_stats
- delta_action_processor.py → relative_action_processor.py
- test_delta_actions.py → test_relative_actions.py

Kept as-is: AbsoluteActionsProcessorStep (converts TO absolute),
registry ID "delta_actions_processor" (backward compat), and unrelated
delta references (IK pipeline, Robosuite, RA-BC metrics, gym envs).

* docs: add Action Representations guide

Dedicated page explaining absolute, relative, and delta actions with
numerical examples, joint vs EE space, and how to use kinematics
pipelines and the relative action processor. References UMI paper
(Chi et al., 2024) for the terminology.

* docs: remove redundant OpenPI naming note from action representations

* docs: remove opinionated OpenPI reference from delta actions section

* docs: replace ASCII diagram with UMI paper figure

* docs: remove OpenPI reference from action representations

* docs: use HF-hosted image instead of local asset

* docs: clarify figure attribution

* revert: restore original normalization epsilon behavior

The 1e-6 unconditional epsilon change perturbed all normalized values,
breaking backward compatibility tests. The original approach (1e-8 eps
for MEAN_STD, conditional torch.where for QUANTILES) already handles
division by zero correctly without affecting non-degenerate cases.

* fix: restore delta_action_processor.py used by phone/RL teleop

The rename commit incorrectly deleted delta_action_processor.py and
duplicated its classes into relative_action_processor.py. Restore the
original file and import from it instead.

* fix(processor): address PR #2970 review comments

- Remove shebang from relative_action_processor.py (library module, not script)
- Add device alignment in to_relative_actions/to_absolute_actions so _last_state
  on CPU doesn't cause cross-device errors when actions are on CUDA
- Rename delta_step → relative_step in AbsoluteActionsProcessorStep for naming
  consistency; update factory.py, all processor files, and tests
- Expand _reconnect_relative_absolute_steps docstring to explain why post-hoc
  rewiring is needed after deserialization
- Fix off-by-one in compute_stats.py: sample_upper_bound = total_frames - chunk_size + 1
  so last valid start index is included and total_frames == chunk_size is not rejected
- Remove redundant NOTE comment in processor_pi05.py (duplicated two lines below)
- Fix pi0_fast processor ordering: move relative_step before NormalizerProcessorStep
  so normalizer sees delta actions (matching pi0/pi05); flip postprocessor to
  unnormalize → absolute accordingly. Relative stats are now required for all pi models
- Revert use_relative_joint_actions_aloha → use_delta_joint_actions_aloha in
  configuration_smolvla.py (preserve existing public API)
- Update action_representations.mdx: add missing joint to 6-DOF example, fix
  'based on a figure', clarify pi family ordering, add RTC compatibility section

* update rtc link

* feat: compute relative action stats over full dataset with optional parallelism

Remove the 100k sample cap from compute_relative_action_stats and process
all valid chunks. Vectorize with numpy (pre-load actions/states, fancy
indexing + broadcasting) for a large speedup over the per-index HF dataset
loop. Add num_workers param for thread-based parallelism (numpy releases
the GIL). Update docs to show --push_to_hub for recompute_stats.

* style: apply ruff formatting to compute_stats.py

* testing on real robot

* style: fix ruff format and remove redundant .keys() calls
2026-04-01 12:59:12 +02:00
Jai Kumaar Ratadia
9300352876 Fix SO-101 assembly instruction order to match videos (#3242)
* Fix SO-101 assembly instruction order to match videos

Motor horn installation steps were listed after placing motors
into the housing, but the assembly videos show installing horns
first. Reordered steps to match the videos, which is also the
easier approach since horns are harder to attach once the motor
is seated. Added missing detail that bottom horns don't require
screws.

* Update docs/source/so101.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Jai Kumaar Ratadia <jaikumaarratadia@gmail.com>

---------

Signed-off-by: Jai Kumaar Ratadia <jaikumaarratadia@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-03-31 12:16:34 +02:00
Steven Palma
720cf8e3a0 Revert "fix(deps): breaking change from transformers 5.4.0" (#3249)
* Revert "fix(deps): breaking change from transformers 5.4.0 (#3231)"

This reverts commit 07502868e5.

* chore(dependecies): pin transformers to 5.3.0 temporarily
2026-03-30 19:11:41 +02:00
Steven Palma
5d4fdf5088 feat(scripts): add transformers version (#3248)
* feat(scripts): add transformers and torch version

* chore(scripts): remove pytorch
2026-03-30 16:33:17 +02:00
四七
3b185f7f9d fix(datasets): remove unreachable timestamp branch in add_frame (#3163)
* fix(datasets): remove unreachable timestamp branch in add_frame and document caller contract

- Remove dead code: frame.pop("timestamp") branch in add_frame() could never
  execute because validate_frame() raises ValueError for any DEFAULT_FEATURES
  key (including timestamp) before we reach that line.
- Expand add_frame() docstring: explicitly document that timestamp and
  frame_index must NOT be passed by the caller.
- Add explanatory comment in validate_frame(): clarifies why DEFAULT_FEATURES
  are excluded from expected_features, preventing future re-introduction of
  the dead branch.

The dead branch originated in #1200, which fixed a shape-(1,) mismatch for a
code path that was subsequently made unreachable by a refactor of validate_frame.

* chore(datasets): narrow PR scope

* fix(datasets): move add_frame timestamp cleanup to dataset_writer
2026-03-28 11:37:57 +01:00
Bryson Jones
2e069b1c47 Feature/add multitask diffusion transformer policy implementation (#2545)
* Add multitask diffusion transformer policy

Add multitask diffusion transformer policy

* expand the observation encoder to support differnt size encoders for vision and text

* add RoPE attention module as this is shown to help training dynamics and generation quality for DiTs

* update readme and citations for multitask dit policy

* remove dino vision encoder and simplify text and vision encoders by removing inheritance structure

* adjust factory comment

* update docstring for multitask dit policy processor file

* simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives

* add references to the modeling file comments

* merge all modules files into the main modeling file

* add torch.no_grad decorators

* split up select action return statement

* remove redundant asserts

* add tutorial to training with multi_task_dit

* fix bugs when testing on hardware

* remove environment state conditioning

* update typo in test instruction comment

* add processor tests to multitask dit tests

* move policy to top of file

* use constants for indexing into batches and remove env state references

* remove the base classes since we don't need to be able to extend

* fix nit formatting in generate actions fcn

* reformat and clean up tutorial for multitask dit policy

* add more descriptions and depth to multitask dit tutorial

* note origins of each training objective

* rename config param for multiple vision encoders

* refactor code to perform task tokenization in the processor instead of in the modeling code for multitask dit

* add multitask dit to toc for docs

* add conditional transformers import to match all other policies that use transformers lib

* add test handling for multitask dit when transformers isnt available

* skip tests without transformers

* remove cropping of images smaller than the crop size

* add kwargs arg to multitask dit constructor

* add wallx dep conflict management for multitask dit policy

* use hyphens for cleanliness in pyproject.toml

* add conflict management to pyproject toml for pi conflict for mtdp as well

* update tests script to not use unnecessary uv sync call which resolves dependencies that do not need to run. This drastically reduces CI run time

* revert fast tests edits

* update docs and readme files, fixing some typos and adding multitask dit to readme

* chore(dependencies): upgrade transformers + hggingface-hub + peft + scipy

* chore(dependencies): bump pi0 family to transformers v5

* chore(dependencies): bump wall x to transformers v5

* chore(dependencies): bump gr00t to transformers v5

* chore(style): fix pre-commit

* fix(policy): xvla forced_bos_token missing

* test(rl): skip ci tests for resnet10

* Fix: full pi models support for transformer v5 (#2967)

* fix(pi): remove loss truncation

* fix(pi): remove state padding before tokenization

* fix(pi): fix image padding value

* fix from_pretrain

* add transformer v5 changes

* remove reference

* more fixes

* make it work

* add support for rest of pi family

* add pifast work

* more changes

* more changes

* more cleanup

* fix torch params

* dtype fix

* torch compile

* embed mismatch fix

* revert groot

* more nit fixes

* remove unused classes

* more fixes

* revert

* nit

* torch dtype warning fix

* but back dynamic renaming

* add tie embedding

---------

Co-authored-by: Yufei Sun <skieyfly@gmail.com>

* chore: fix XVLA in transformers v5 (#3006)

* test(policies): enable wall x CI testing

* style(test): pre-commit check

* style(test): pre-commit

---------

Signed-off-by: Bryson Jones <63133702+brysonjones@users.noreply.github.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Yufei Sun <skieyfly@gmail.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2026-03-28 00:41:26 +01:00
Steven Palma
4e45acca52 fix(dataset): use revision-safe Hub cache for downloaded datasets (#3233)
* refactor(dataset): enhance dataset root directory handling and introduce hub cache support

- Updated DatasetConfig and LeRobotDatasetMetadata to clarify root directory behavior and introduce a dedicated hub cache for downloads.
- Refactored LeRobotDataset and StreamingLeRobotDataset to utilize the new hub cache and improve directory management.
- Added tests to ensure correct behavior when using the hub cache and handling different revisions without a specified root directory.

* refactor(dataset): improve root directory handling in LeRobotDataset

- Updated LeRobotDataset to store the requested root path separately from the actual root path.
- Adjusted metadata loading to use the requested root, enhancing clarity and consistency in directory management.

* refactor(dataset): minor improvements for hub cache support

* chore(datasets): guard in resume + assertion test

---------

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: mickaelChen <mickael.chen.levinson@gmail.com>
2026-03-27 22:21:55 +01:00
Maxime Ellerbach
975d89b38d chore(docs): add more guidance to bring your own policies tutorial (#3230)
* chore(docs): add more guidance to bring your own policies tutorial

* removing normalization to avoid confusion with processors

* trailing whitespace

* Update docs/source/bring_your_own_policies.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* Update docs/source/bring_your_own_policies.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adding get optim params and predict_action chunk

* removing extra quotes

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
2026-03-27 21:25:37 +01:00
Maxime Ellerbach
07502868e5 fix(deps): breaking change from transformers 5.4.0 (#3231)
* fix(deps): breaking change from transformers 5.4.0

* Update src/lerobot/policies/xvla/modeling_florence2.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* Update src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* removing dataclass

* bumping transformers 5.4.0

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-27 21:25:12 +01:00
Reece O'Mahoney
aa9cc9bd43 fix(logging): suppress noisy httpx INFO logs (#3173)
Set httpx logger level to WARNING in init_logging to prevent
HTTP request traces from flooding the terminal during train and
eval scripts.

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-03-26 21:05:15 +01:00
Steven Palma
123495250b refactor(dataset): split LeRobotDataset into DatasetReader & DatasetWriter (+ API cleanup) (#3180)
* refactor(dataset): split reader and writer

* chore(dataset): remove proxys

* refactor(dataset): better reader & writer encapsulation

* refactor(datasets): clean API + reduce leaky implementations

* refactor(dataset): API cleaning for writer, reader and meta

* refactor(dataset): expose writer & reader + other minor improvements

* refactor(dataset): improve teardown routine

* refactor(dataset): add hf_dataset property at the facade level

* chore(dataset): add init for datasset module

* docs(dataset): add docstrings for public API of the dataset classes

* tests(dataset): add tests for new classes

* fix(dataset): remove circular dependecy
2026-03-26 19:09:25 +01:00
Jade Choghari
017ff73fbf chore(docs): add rename map and empty cam guide (#3065)
* add blog/guide

* add to tree

* chore(docs): rephrase rename_map docs for clarity and simplicity

---------

Co-authored-by: Steven Palma <steven.palma@huggingface.co>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-03-23 13:57:53 -07:00
Praedico
f90db58c15 docs(async): fix GitHub issues link (#3186) 2026-03-19 22:32:07 -07:00
Altman
e64fa667c3 fix(vqbet): use in-place fill_ to avoid overwriting DDP GPU buffers with CPU tensors (#3128)
* fix(vqbet): use in-place fill_ to avoid overwriting DDP GPU buffers with CPU tensors

When VQ discretization phase completes, the code was overwriting
register_buffer('discretized') and register_buffer('freeze_codebook')
with torch.tensor(True), which is created on CPU. DDP then fails in
_sync_buffers() with: RuntimeError: No backend type associated with
device type cpu. Fix by updating the buffers in-place with .fill_(True)
so device and registration are preserved.

Made-with: Cursor

* test(vqbet): add regression test for in-place buffer update during discretization

Verifies that discretize() updates the 'discretized' and 'freeze_codebook'
registered buffers in-place (via fill_()) rather than replacing them with new
CPU tensors. The test checks data_ptr() identity and that the tensors remain
registered buffers after the call. This prevents regressions of the DDP fix.

Made-with: Cursor

* test(vqbet): add GPU regression test to verify buffers stay on CUDA after discretize()

Directly catches the original DDP failure mode: when buffers are replaced with
torch.tensor(True) they land on CPU, causing NCCL to raise 'No backend type
associated with device type cpu' in _sync_buffers(). The GPU test places the
model on cuda:0 and asserts both buffers remain on CUDA after discretization.

Made-with: Cursor

* test(vqbet): simplify to single device-check test in test_policies.py

Per reviewer feedback: remove the separate test file and replace the two
CPU/GPU tests (with data_ptr checks) with a single focused test in
tests/policies/test_policies.py that only asserts the registered buffers
remain on the model device after discretize(). Uses DEVICE from tests/utils.py
so it runs on whatever device the CI/user selects (cpu, cuda, mps).

Made-with: Cursor

* style: fix import order in test_policies.py to pass ruff/pre-commit checks

Made-with: Cursor

---------

Co-authored-by: Zhan DiJia <2476100824@example.com>
Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
2026-03-18 13:24:07 +01:00
Khalil Meftah
d9ec3a6fa2 Fix/earth rover dataset features (#3088)
* docs(earthrover): update EarthRover Mini Plus dataset features and descriptions

* refactor(teleop): rename rover action keys to linear_velocity/angular_velocity

* fix(earthrover): align observation and action features with frodobots/berkeley-frodobots-lerobot-7k

* chore: address PR review comments

* ci: retrigger checks
2026-03-17 18:33:53 +01:00
Steven Palma
d90e4bcfd3 refactor(dataset): modular files (#3171)
* refactor(dataset): modular files

* refactor(dataset): update imports across the codebase
2026-03-15 23:58:09 -07:00
Steven Palma
9d3b62aa61 chore(dataset): basic house-keeping (#3170) 2026-03-15 22:12:09 -07:00
99 changed files with 12251 additions and 1787 deletions

View File

@@ -100,11 +100,11 @@ lerobot-train \
--dataset.repo_id=lerobot/aloha_mobile_cabinet
```
| Category | Models |
| -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
| Category | Models |
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub

View File

@@ -17,8 +17,12 @@
title: Train RL in Simulation
- local: multi_gpu_training
title: Multi GPU training
- local: hil_data_collection
title: Human In the Loop Data Collection
- local: peft_training
title: Training with PEFT (e.g., LoRA)
- local: rename_map
title: Using Rename Map and Empty Cameras
title: "Tutorials"
- sections:
- local: lerobot-dataset-v3
@@ -47,6 +51,8 @@
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
- local: multi_task_dit
title: Multitask DiT Policy
- local: walloss
title: WALL-OSS
title: "Policies"
@@ -65,13 +71,17 @@
title: Environments from the Hub
- local: envhub_leisaac
title: Control & Train Robots in Sim (LeIsaac)
title: "Simulation"
- sections:
- local: adding_benchmarks
title: Adding a New Benchmark
- local: libero
title: LIBERO
- local: metaworld
title: Meta-World
- local: envhub_isaaclab_arena
title: NVIDIA IsaacLab Arena Environments
- local: libero
title: Using Libero
- local: metaworld
title: Using MetaWorld
title: "Simulation"
title: "Benchmarks"
- sections:
- local: introduction_processors
title: Introduction to Robot Processors
@@ -83,6 +93,8 @@
title: Processors for Robots and Teleoperators
- local: env_processor
title: Environment Processors
- local: action_representations
title: Action Representations
title: "Robot Processors"
- sections:
- local: so101

View File

@@ -0,0 +1,223 @@
# Action Representations
This guide explains the different ways robot actions can be represented in LeRobot, how they relate to each other, and when to use each one.
## Joint Space vs End-Effector Space
Before discussing action representations, it helps to understand the two coordinate spaces actions can live in.
### Joint Space
Joint-space actions directly specify target positions for each motor. For a 6-DOF arm with a gripper, a joint-space action might look like:
```
action = [shoulder_pan: 45.0, shoulder_lift: -20.0, elbow: -30.0, wrist_pitch: 10.0, wrist_roll: 0.0, wrist_yaw: 5.0, gripper: 0.8]
```
Joint space is the default in LeRobot. It is simple, requires no kinematics model, and maps directly to motor commands. Most beginner setups (SO-100, Koch) use joint-space actions.
### End-Effector (EE) Space
End-effector-space actions specify the desired position and orientation of the robot's tool tip (gripper) in Cartesian coordinates:
```
action = [x: 0.25, y: -0.10, z: 0.15, wx: 0.0, wy: 0.0, wz: 0.1, gripper: 0.8]
```
EE space is more intuitive for tasks like pick-and-place because it directly describes where the gripper should go, but it requires a kinematics model (URDF) to convert between EE poses and joint angles.
### Converting Between Spaces
LeRobot provides processor steps for converting between joint and EE spaces using forward and inverse kinematics. These are built on top of `RobotKinematics`, which loads a URDF model of your robot.
```python
from lerobot.model.kinematics import RobotKinematics
from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
kinematics = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=["shoulder", "elbow", "wrist_pitch", "wrist_roll", "wrist_yaw"],
)
# Joints → EE (for observations: "where is my gripper?")
fk_step = ForwardKinematicsJointsToEE(kinematics=kinematics, motor_names=[...])
# EE → Joints (for actions: "move my gripper here")
ik_step = InverseKinematicsEEToJoints(kinematics=kinematics, motor_names=[...])
```
See [`examples/so100_to_so100_EE/`](https://github.com/huggingface/lerobot/tree/main/examples/so100_to_so100_EE) for a complete working example of recording, replaying, and evaluating with EE-space actions on an SO-100 arm.
## Absolute, Relative, and Delta Actions
Regardless of whether you work in joint space or EE space, the action values can be expressed in three different ways. The terminology follows [UMI (Chi et al., 2024)](https://arxiv.org/abs/2402.10329).
### Absolute Actions (LeRobot default)
Each action specifies the target position directly.
**Example** (joint space, chunk of 4):
```
current_state = [45.0, -30.0, 10.0]
action_chunk = [
[46.0, -29.0, 11.0], # go to 46, -29, 11
[47.5, -27.0, 12.0], # go to 47.5, -27, 12
[49.0, -25.0, 13.5], # go to 49, -25, 13.5
[50.0, -24.0, 15.0], # go to 50, -24, 15
]
```
Each value is a target position in the robot's coordinate frame. Simple and direct, but requires a consistent global coordinate frame. This is the default in LeRobot.
### Relative Actions (used by OpenPI / pi0)
Each action in the chunk is an offset from the **current state at the moment of prediction**. All actions in the chunk share the same reference point:
```
current_state = [45.0, -30.0, 10.0]
relative_chunk = [
[1.0, 1.0, 1.0], # +1 from current → target 46, -29, 11
[2.5, 3.0, 2.0], # +2.5 from current → target 47.5, -27, 12
[4.0, 5.0, 3.5], # +4 from current → target 49, -25, 13.5
[5.0, 6.0, 5.0], # +5 from current → target 50, -24, 15
]
```
The conversion is straightforward: `relative = absolute - current_state`. To recover absolute: `absolute = relative + current_state`.
**Why use relative actions?** The model learns to predict offsets centered around zero, which is easier to normalize and leads to more stable training. Because every chunk references the same current state, there is no error accumulation across chunks.
### Delta Actions (sequential differences)
Each action is an offset from the **previous action** (or from the current state for the first step):
```
current_state = [45.0, -30.0, 10.0]
delta_chunk = [
[1.0, 1.0, 1.0], # current → 46, -29, 11
[1.5, 2.0, 1.0], # previous action → 47.5, -27, 12
[1.5, 2.0, 1.5], # previous action → 49, -25, 13.5
[1.0, 1.0, 1.5], # previous action → 50, -24, 15
]
```
Here each step is relative to the one before it. To recover absolute positions you must sum all previous deltas, which means errors accumulate over time. UMI explicitly argues against this representation for this reason.
### Visual Comparison
The figure below (based on a figure from [UMI, Chi et al., 2024](https://arxiv.org/abs/2402.10329)) illustrates the key difference. With **relative trajectory**, every action in the chunk points back to the same origin (current state), so a new inference step cleanly resets the reference. With **delta**, each action depends on the previous one, so errors accumulate. **Absolute** actions require a consistent global coordinate frame.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/action_representations_umi.png"
alt="Relative Trajectory as Action Representation (UMI, Chi et al., 2024)"
width="85%"
/>
## Using Relative Actions in LeRobot
LeRobot provides `RelativeActionsProcessorStep` to convert between absolute and relative actions inside the processor pipeline. This is how pi0, pi0.5, and pi0_fast support relative actions.
> **Note:** All pi models (pi0, pi0.5, pi0*fast) apply relative conversion \_before* normalization (`relative → normalize`), so the normalizer always sees delta (relative) values. This means **relative action stats are required** for all of them when training with `use_relative_actions=true`. In pi0_fast the `RelativeActionsProcessorStep` only modifies the action — the state observation is unchanged — so `NormalizerProcessorStep` still runs before the state tokenizer and the tokenizer continues to receive normalized state as expected.
### How it works
During **training** (preprocessing), actions are converted from absolute to relative before the model sees them:
```
raw absolute action → RelativeActionsProcessorStep → normalize → model
```
During **inference** (postprocessing), model predictions are converted back to absolute before being sent to the robot:
```
model output → unnormalize → AbsoluteActionsProcessorStep → robot
```
The `AbsoluteActionsProcessorStep` reads the cached current state from its paired `RelativeActionsProcessorStep`, so the two must be wired together (handled automatically by the policy factory).
### Enabling relative actions for the pi family (pi0, pi0.5, pi0_fast)
**Step 1**: Precompute relative action statistics for your dataset:
```bash
lerobot-edit-dataset \
--repo_id your_dataset \
--operation.type recompute_stats \
--operation.relative_action true \
--operation.chunk_size 50 \
--operation.relative_exclude_joints "['gripper']"
```
**Step 2**: Train with relative actions enabled:
```bash
lerobot-train \
--dataset.repo_id=your_dataset \
--policy.type=pi0 \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]'
```
The `relative_exclude_joints` parameter specifies joints that should remain in absolute space. For example, gripper commands are typically binary (open/close) and don't benefit from relative encoding.
### Combining relative actions with RTC
[RTC](https://arxiv.org/abs/2506.07339) runs policy inference at high frequency and sends actions to the robot as they are predicted rather than waiting for a full chunk. Relative actions and RTC are fully compatible: because every chunk in relative mode references the **same** current state (captured at the start of inference), each predicted action in the chunk remains a valid offset even if the robot has already moved. No special handling is needed — `RelativeActionsProcessorStep` caches the state once per inference call and `AbsoluteActionsProcessorStep` applies it to every action in the streamed output.
### Combining relative actions with EE space
Relative actions work in both joint space and EE space. For example, if your dataset stores EE actions, relative encoding converts them to offsets from the current EE pose:
```
current_ee_state = [x: 0.25, y: -0.10, z: 0.15, gripper: 0.8]
absolute_ee_chunk = [
[0.26, -0.09, 0.16, 0.8],
[0.28, -0.07, 0.18, 0.8],
]
relative_ee_chunk = [
[0.01, 0.01, 0.01, 0.0], # offset from current EE pose
[0.03, 0.03, 0.03, 0.0], # offset from current EE pose
]
```
## Processing Pipeline Summary
Here is how the different processors compose. Each arrow is a processor step, and they can be chained in a `RobotProcessorPipeline` or `PolicyProcessorPipeline`:
```
┌─────────────────────────────────────────┐
Action Space │ Joint Space ←──IK──→ EE Space │
│ ForwardKinematicsJointsToEE │
│ InverseKinematicsEEToJoints │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
Representation │ Absolute ←────→ Relative │
│ RelativeActionsProcessorStep (pre) │
│ AbsoluteActionsProcessorStep (post) │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
Normalization │ Raw ←────→ Normalized │
│ NormalizerProcessorStep (pre) │
│ UnnormalizerProcessorStep (post) │
└─────────────────────────────────────────┘
```
A typical training preprocessor might chain: `raw absolute joint actions → relative → normalize`. A typical inference postprocessor: `unnormalize → absolute → (optionally IK to joints)`.
## References
- [Universal Manipulation Interface (UMI)](https://arxiv.org/abs/2402.10329) - Chi et al., 2024. Defines the relative trajectory action representation and compares it with absolute and delta actions.
- [Introduction to Processors](./introduction_processors) - How processor pipelines work in LeRobot.
- [`examples/so100_to_so100_EE/`](https://github.com/huggingface/lerobot/tree/main/examples/so100_to_so100_EE) - Complete example of recording and evaluating with EE-space actions.

View File

@@ -0,0 +1,320 @@
# Adding a New Benchmark
This guide walks you through adding a new simulation benchmark to LeRobot. Follow the steps in order and use the existing benchmarks as templates.
A benchmark in LeRobot is a set of [Gymnasium](https://gymnasium.farama.org/) environments that wrap a third-party simulator (like LIBERO or Meta-World) behind a standard `gym.Env` interface. The `lerobot-eval` CLI then runs evaluation uniformly across all benchmarks.
## Existing benchmarks at a glance
Before diving in, here is what is already integrated:
| Benchmark | Env file | Config class | Tasks | Action dim | Processor |
| -------------- | ------------------- | ------------------ | ------------------- | ------------ | ---------------------------- |
| LIBERO | `envs/libero.py` | `LiberoEnv` | 130 across 5 suites | 7 | `LiberoProcessorStep` |
| Meta-World | `envs/metaworld.py` | `MetaworldEnv` | 50 (MT50) | 4 | None |
| IsaacLab Arena | Hub-hosted | `IsaaclabArenaEnv` | Configurable | Configurable | `IsaaclabArenaProcessorStep` |
Use `src/lerobot/envs/libero.py` and `src/lerobot/envs/metaworld.py` as reference implementations.
## How it all fits together
### Data flow
During evaluation, data moves through four stages:
```
1. gym.Env ──→ raw observations (numpy dicts)
2. Preprocessing ──→ standard LeRobot keys + task description
(preprocess_observation, add_envs_task in envs/utils.py)
3. Processors ──→ env-specific then policy-specific transforms
(env_preprocessor, policy_preprocessor)
4. Policy ──→ select_action() ──→ action tensor
then reverse: policy_postprocessor → env_postprocessor → numpy action → env.step()
```
Most benchmarks only need to care about stage 1 (producing observations in the right format) and optionally stage 3 (if env-specific transforms are needed).
### Environment structure
`make_env()` returns a nested dict of vectorized environments:
```python
dict[str, dict[int, gym.vector.VectorEnv]]
# ^suite ^task_id
```
A single-task env (e.g. PushT) looks like `{"pusht": {0: vec_env}}`.
A multi-task benchmark (e.g. LIBERO) looks like `{"libero_spatial": {0: vec0, 1: vec1, ...}, ...}`.
### How evaluation runs
All benchmarks are evaluated the same way by `lerobot-eval`:
1. `make_env()` builds the nested `{suite: {task_id: VectorEnv}}` dict.
2. `eval_policy_all()` iterates over every suite and task.
3. For each task, it runs `n_episodes` rollouts via `rollout()`.
4. Results are aggregated hierarchically: episode, task, suite, overall.
5. Metrics include `pc_success` (success rate), `avg_sum_reward`, and `avg_max_reward`.
The critical piece: your env must return `info["is_success"]` on every `step()` call. This is how the eval loop knows whether a task was completed.
## What your environment must provide
LeRobot does not enforce a strict observation schema. Instead it relies on a set of conventions that all benchmarks follow.
### Env attributes
Your `gym.Env` must set these attributes:
| Attribute | Type | Why |
| -------------------- | ----- | ---------------------------------------------------- |
| `_max_episode_steps` | `int` | `rollout()` uses this to cap episode length |
| `task_description` | `str` | Passed to VLA policies as a language instruction |
| `task` | `str` | Fallback identifier if `task_description` is not set |
### Success reporting
Your `step()` and `reset()` must include `"is_success"` in the `info` dict:
```python
info = {"is_success": True} # or False
return observation, reward, terminated, truncated, info
```
### Observations
The simplest approach is to map your simulator's outputs to the standard keys that `preprocess_observation()` already understands. Do this inside your `gym.Env` (e.g. in a `_format_raw_obs()` helper):
| Your env should output | LeRobot maps it to | What it is |
| ------------------------- | -------------------------- | ------------------------------------- |
| `"pixels"` (single array) | `observation.image` | Single camera image, HWC uint8 |
| `"pixels"` (dict) | `observation.images.<cam>` | Multiple cameras, each HWC uint8 |
| `"agent_pos"` | `observation.state` | Proprioceptive state vector |
| `"environment_state"` | `observation.env_state` | Full environment state (e.g. PushT) |
| `"robot_state"` | `observation.robot_state` | Nested robot state dict (e.g. LIBERO) |
If your simulator uses different key names, you have two options:
1. **Recommended:** Rename them to the standard keys inside your `gym.Env` wrapper.
2. **Alternative:** Write an env processor to transform observations after `preprocess_observation()` runs (see step 4 below).
### Actions
Actions are continuous numpy arrays in a `gym.spaces.Box`. The dimensionality depends on your benchmark (7 for LIBERO, 4 for Meta-World, etc.). Policies adapt to different action dimensions through their `input_features` / `output_features` config.
### Feature declaration
Each `EnvConfig` subclass declares two dicts that tell the policy what to expect:
- `features` — maps feature names to `PolicyFeature(type, shape)` (e.g. action dim, image shape).
- `features_map` — maps raw observation keys to LeRobot convention keys (e.g. `"agent_pos"` to `"observation.state"`).
## Step by step
<Tip>
At minimum, you need two files: a **gym.Env wrapper** and an **EnvConfig
subclass** with a `create_envs()` override. Everything else is optional or
documentation. No changes to `factory.py` are needed.
</Tip>
### Checklist
| File | Required | Why |
| ---------------------------------------- | -------- | ------------------------------------------------------------ |
| `src/lerobot/envs/<benchmark>.py` | Yes | Wraps the simulator as a standard gym.Env |
| `src/lerobot/envs/configs.py` | Yes | Registers your benchmark and its `create_envs()` for the CLI |
| `src/lerobot/processor/env_processor.py` | Optional | Custom observation/action transforms |
| `src/lerobot/envs/utils.py` | Optional | Only if you need new raw observation keys |
| `pyproject.toml` | Yes | Declares benchmark-specific dependencies |
| `docs/source/<benchmark>.mdx` | Yes | User-facing documentation page |
| `docs/source/_toctree.yml` | Yes | Adds your page to the docs sidebar |
### 1. The gym.Env wrapper (`src/lerobot/envs/<benchmark>.py`)
Create a `gym.Env` subclass that wraps the third-party simulator:
```python
class MyBenchmarkEnv(gym.Env):
metadata = {"render_modes": ["rgb_array"], "render_fps": <fps>}
def __init__(self, task_suite, task_id, ...):
super().__init__()
self.task = <task_name_string>
self.task_description = <natural_language_instruction>
self._max_episode_steps = <max_steps>
self.observation_space = spaces.Dict({...})
self.action_space = spaces.Box(low=..., high=..., shape=(...,), dtype=np.float32)
def reset(self, seed=None, **kwargs):
... # return (observation, info) — info must contain {"is_success": False}
def step(self, action: np.ndarray):
... # return (obs, reward, terminated, truncated, info) — info must contain {"is_success": <bool>}
def render(self):
... # return RGB image as numpy array
def close(self):
...
```
Also provide a factory function that returns the nested dict structure:
```python
def create_mybenchmark_envs(
task: str,
n_envs: int,
gym_kwargs: dict | None = None,
env_cls: type | None = None,
) -> dict[str, dict[int, Any]]:
"""Create {suite_name: {task_id: VectorEnv}} for MyBenchmark."""
...
```
See `create_libero_envs()` (multi-suite, multi-task) and `create_metaworld_envs()` (difficulty-grouped tasks) for reference.
### 2. The config (`src/lerobot/envs/configs.py`)
Register a config dataclass so users can select your benchmark with `--env.type=<name>`. Each config owns its environment creation and processor logic via two methods:
- **`create_envs(n_envs, use_async_envs)`** — Returns `{suite: {task_id: VectorEnv}}`. The base class default uses `gym.make()` for single-task envs. Multi-task benchmarks override this.
- **`get_env_processors()`** — Returns `(preprocessor, postprocessor)`. The base class default returns identity (no-op) pipelines. Override if your benchmark needs observation/action transforms.
```python
@EnvConfig.register_subclass("<benchmark_name>")
@dataclass
class MyBenchmarkEnvConfig(EnvConfig):
task: str = "<default_task>"
fps: int = <fps>
obs_type: str = "pixels_agent_pos"
features: dict[str, PolicyFeature] = field(default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(<action_dim>,)),
})
features_map: dict[str, str] = field(default_factory=lambda: {
ACTION: ACTION,
"agent_pos": OBS_STATE,
"pixels": OBS_IMAGE,
})
def __post_init__(self):
... # populate features based on obs_type
@property
def gym_kwargs(self) -> dict:
return {"obs_type": self.obs_type, "render_mode": self.render_mode}
def create_envs(self, n_envs: int, use_async_envs: bool = False):
"""Override for multi-task benchmarks or custom env creation."""
from lerobot.envs.<benchmark> import create_<benchmark>_envs
return create_<benchmark>_envs(task=self.task, n_envs=n_envs, ...)
def get_env_processors(self):
"""Override if your benchmark needs observation/action transforms."""
from lerobot.processor.pipeline import PolicyProcessorPipeline
from lerobot.processor.env_processor import MyBenchmarkProcessorStep
return (
PolicyProcessorPipeline(steps=[MyBenchmarkProcessorStep()]),
PolicyProcessorPipeline(steps=[]),
)
```
Key points:
- The `register_subclass` name is what users pass on the CLI (`--env.type=<name>`).
- `features` tells the policy what the environment produces.
- `features_map` maps raw observation keys to LeRobot convention keys.
- **No changes to `factory.py` needed** — the factory delegates to `cfg.create_envs()` and `cfg.get_env_processors()` automatically.
### 3. Env processor (optional — `src/lerobot/processor/env_processor.py`)
Only needed if your benchmark requires observation transforms beyond what `preprocess_observation()` handles (e.g. image flipping, coordinate conversion). Define the processor step here and return it from `get_env_processors()` in your config (see step 2):
```python
@dataclass
@ProcessorStepRegistry.register(name="<benchmark>_processor")
class MyBenchmarkProcessorStep(ObservationProcessorStep):
def _process_observation(self, observation):
processed = observation.copy()
# your transforms here
return processed
def transform_features(self, features):
return features # update if shapes change
def observation(self, observation):
return self._process_observation(observation)
```
See `LiberoProcessorStep` for a full example (image rotation, quaternion-to-axis-angle conversion).
### 4. Dependencies (`pyproject.toml`)
Add a new optional-dependency group:
```toml
mybenchmark = ["my-benchmark-pkg==1.2.3", "lerobot[scipy-dep]"]
```
Pinning rules:
- **Always pin** benchmark packages to exact versions for reproducibility (e.g. `metaworld==3.0.0`).
- **Add platform markers** when needed (e.g. `; sys_platform == 'linux'`).
- **Pin fragile transitive deps** if known (e.g. `gymnasium==1.1.0` for Meta-World).
- **Document constraints** in your benchmark doc page.
Users install with:
```bash
pip install -e ".[mybenchmark]"
```
### 5. Documentation (`docs/source/<benchmark>.mdx`)
Write a user-facing page following the template in the next section. See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for full examples.
### 6. Table of contents (`docs/source/_toctree.yml`)
Add your benchmark to the "Benchmarks" section:
```yaml
- sections:
- local: libero
title: LIBERO
- local: metaworld
title: Meta-World
- local: envhub_isaaclab_arena
title: NVIDIA IsaacLab Arena Environments
- local: <your_benchmark>
title: <Your Benchmark Name>
title: "Benchmarks"
```
## Verifying your integration
After completing the steps above, confirm that everything works:
1. **Install** — `pip install -e ".[mybenchmark]"` and verify the dependency group installs cleanly.
2. **Smoke test env creation** — call `make_env()` with your config in Python, check that the returned dict has the expected `{suite: {task_id: VectorEnv}}` shape, and that `reset()` returns observations with the right keys.
3. **Run a full eval** — `lerobot-eval --env.type=<name> --env.task=<task> --eval.n_episodes=1 --eval.batch_size=1 --policy.path=<any_compatible_policy>` to exercise the full pipeline end-to-end.
4. **Check success detection** — verify that `info["is_success"]` flips to `True` when the task is actually completed. This is what the eval loop uses to compute success rates.
## Writing a benchmark doc page
Each benchmark `.mdx` page should include:
- **Title and description** — 1-2 paragraphs on what the benchmark tests and why it matters.
- **Links** — paper, GitHub repo, project website (if available).
- **Overview image or GIF.**
- **Available tasks** — table of task suites with counts and brief descriptions.
- **Installation** — `pip install -e ".[<benchmark>]"` plus any extra steps (env vars, system packages).
- **Evaluation** — recommended `lerobot-eval` command with `n_episodes` and `batch_size` for reproducible results. Include single-task and multi-task examples if applicable.
- **Policy inputs and outputs** — observation keys with shapes, action space description.
- **Recommended evaluation episodes** — how many episodes per task is standard.
- **Training** — example `lerobot-train` command.
- **Reproducing published results** — link to pretrained model, eval command, results table (if available).
See `docs/source/libero.mdx` and `docs/source/metaworld.mdx` for complete examples.

View File

@@ -310,4 +310,4 @@ Asynchronous inference represents a significant advancement in real-time robotic
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).

View File

@@ -41,13 +41,15 @@ requires = # your-build-system
## Step 2: Define the Policy Configuration
Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type:
Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
Here is a template to get you started, customize the parameters and methods as needed for your policy's architecture and training requirements.
```python
# configuration_my_custom_policy.py
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("my_custom_policy")
@dataclass
@@ -61,22 +63,56 @@ class MyCustomPolicyConfig(PreTrainedConfig):
hidden_dim: Hidden dimension for the policy network
# Add your policy-specific parameters here
"""
# ...PreTrainedConfig fields...
pass
horizon: int = 50
n_action_steps: int = 50
hidden_dim: int = 256
optimizer_lr: float = 1e-4
optimizer_weight_decay: float = 1e-4
def __post_init__(self):
super().__post_init__()
# Add any validation logic here
if self.n_action_steps > self.horizon:
raise ValueError("n_action_steps cannot exceed horizon")
def validate_features(self) -> None:
"""Validate input/output feature compatibility."""
# Implement validation logic for your policy's requirements
pass
if not self.image_features:
raise ValueError("MyCustomPolicy requires at least one image feature.")
if self.action_feature is None:
raise ValueError("MyCustomPolicy requires 'action' in output_features.")
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
def get_scheduler_preset(self):
return None
@property
def observation_delta_indices(self) -> list[int] | None:
"""Relative timestep offsets the dataset loader provides per observation.
Return `None` for single-frame policies. For temporal policies that consume
multiple past or future frames, return a list of offsets, e.g. `[-20, -10, 0, 10]` for
3 past frames at stride 10 and 1 future frame at stride 10.
"""
return None
@property
def action_delta_indices(self) -> list[int]:
"""Relative timestep offsets for the action chunk the dataset loader returns.
"""
return list(range(self.horizon))
@property
def reward_delta_indices(self) -> None:
return None
```
## Step 3: Implement the Policy Class
Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class:
Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):
```python
# modeling_my_custom_policy.py
@@ -85,38 +121,74 @@ import torch.nn as nn
from typing import Any
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION
from .configuration_my_custom_policy import MyCustomPolicyConfig
class MyCustomPolicy(PreTrainedPolicy):
config_class = MyCustomPolicyConfig
config_class = MyCustomPolicyConfig # must match the string in @register_subclass
name = "my_custom_policy"
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
super().__init__(config, dataset_stats)
config.validate_features() # not called automatically by the base class
self.config = config
self.model = ... # your nn.Module here
def reset(self):
"""Reset episode state."""
...
def get_optim_params(self) -> dict:
"""Return parameters to pass to the optimizer (e.g. with per-group lr/wd)."""
return {"params": self.parameters()}
def predict_action_chunk(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
"""Return the full action chunk (B, chunk_size, action_dim) for the current observation."""
...
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
"""Return a single action for the current timestep (called at inference)."""
...
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Compute the training loss.
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
timesteps padded because the episode ended before `horizon` steps, you
can exclude those from your loss.
"""
actions = batch[ACTION]
action_is_pad = batch.get("action_is_pad")
...
return {"loss": ...}
```
## Step 4: Add Data Processors
Create processor functions:
Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
```python
# processor_my_custom_policy.py
from typing import Any
import torch
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
def make_my_custom_policy_pre_post_processors(
config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Create preprocessing and postprocessing functions for your policy."""
pass # Define your preprocessing and postprocessing logic here
preprocessor = ... # build your PolicyProcessorPipeline for inputs
postprocessor = ... # build your PolicyProcessorPipeline for outputs
return preprocessor, postprocessor
```
**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
## Step 5: Package Initialization
Expose your classes in the package's `__init__.py`:

View File

@@ -204,22 +204,26 @@ Replace `your_username/dataset_name` with your Hugging Face username and a name
Your dataset includes:
**Your Actions (2 things)**:
**Your Actions (2 features)**:
- How much you moved forward/backward
- How much you turned left/right
- `linear_velocity`: How much you moved forward/backward
- `angular_velocity`: How much you turned left/right
**Robot Observations (12 things)**:
**Robot Observations (24 features)**:
- Front camera video
- Rear camera video
- Current speed
- Battery level
- Which way the robot is facing
- GPS location (latitude, longitude, signal strength)
- Orientation
- GPS (latitude, longitude, signal strength)
- Network signal strength
- Vibration level
- Lamp status (on/off)
- Lamp state (on/off)
- Accelerometer (x, y, z)
- Gyroscope (x, y, z)
- Magnetometer (x, y, z)
- Wheel RPMs (4 wheels)
### Where Your Data Goes

View File

@@ -151,7 +151,7 @@ observation = {
### Factory Function
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
The `make_env_pre_post_processors` function delegates to `env_cfg.get_env_processors()`:
```python
from lerobot.envs.factory import make_env_pre_post_processors
@@ -159,47 +159,31 @@ from lerobot.envs.configs import LiberoEnv, PushtEnv
# For LIBERO: Returns LiberoProcessorStep in preprocessor
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg, policy_cfg)
# For other environments: Returns identity processors (no-op)
pusht_cfg = PushtEnv()
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg, policy_cfg)
```
### Implementation in `envs/factory.py`
### How It Works
Each `EnvConfig` subclass can override `get_env_processors()` to return benchmark-specific
processor pipelines. The base class returns identity (no-op) processors by default.
```python
def make_env_pre_post_processors(
env_cfg: EnvConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
Args:
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs
"""
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
else:
# For all other environments, return an identity preprocessor
preprocessor = PolicyProcessorPipeline(steps=[])
# Postprocessor is currently identity for all environments
# Future: Could add environment-specific action transformations
postprocessor = PolicyProcessorPipeline(steps=[])
return preprocessor, postprocessor
# In your EnvConfig subclass:
def get_env_processors(self):
from lerobot.processor.pipeline import PolicyProcessorPipeline
return (
PolicyProcessorPipeline(steps=[MyProcessorStep()]),
PolicyProcessorPipeline(steps=[]),
)
```
The factory function `make_env_pre_post_processors` simply delegates to this method,
with a special case for `XVLAConfig` policies which override the env processors entirely.
### Integration in Evaluation
In `lerobot_eval.py`, the environment processors are created once and used throughout:

View File

@@ -0,0 +1,269 @@
# Human-In-the-Loop Data Collection
Human-In-the-Loop (HIL) data collection lets you improve a trained policy by deploying it on a real robot while a human operator monitors and intervenes when needed. The intervention data (recovery movements and corrections) is recorded alongside autonomous segments, producing a richer training dataset that teaches the policy how to handle failures.
---
## Why Human-In-the-Loop?
Standard behavioral cloning trains policies on successful demonstrations only. During deployment, small errors can compound and push the robot into states never seen during training (distribution shift). HIL data collection addresses this by:
- Running the trained policy on the real robot
- Having a human intervene when the robot is about to fail
- Recording the human's recovery and correction as training data
- Fine-tuning the policy on the combined dataset
This produces a policy that not only knows how to perform the task, but also how to recover when things go wrong.
---
## How It Works
During a HIL session, the human operator follows this loop within each episode:
1. **Watch** the policy run autonomously
2. **Pause** when failure is imminent, the robot holds its position
3. **Take control** and teleoperate the robot back to a good state (recovery), then correct the behavior
4. **Return control to the policy**, the policy resumes autonomous execution
5. Repeat steps 24 as many times as needed during the episode
6. **End the episode** when the task is complete, save and move on to the next rollout
Both autonomous and human-controlled segments are recorded. The policy and human can alternate control multiple times within a single episode, and the episode continues from the current state after each handoff (no reset required just because intervention happened). This captures autonomous execution, recovery, and correction in one continuous trajectory. After collection, the combined dataset (original demonstrations + HIL data) is used to fine-tune the policy.
This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Each round targets the current policy's failure modes.
```
┌─────────────────────────────────────────────────────────────────────────┐
│ Policy v0 (trained on demos) │
│ ↓ │
│ HIL Collection (target current failure modes) → Fine-tune → Policy v1 │
│ ↓ │
│ HIL Collection (target new failure modes) → Fine-tune → Policy v2 │
│ ↓ │
│ ... (repeat until satisfactory performance) │
└─────────────────────────────────────────────────────────────────────────┘
```
---
## Hardware Requirements
### Teleoperator Requirements
The `examples/hil` HIL scripts require **teleoperators with active motors** that can:
- Enable/disable torque programmatically
- Move to target positions (to mirror the robot state when pausing)
**Compatible teleoperators in the current `examples/hil` scripts:**
- `openarm_mini` - OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT]
> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
---
## Script
A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`:
| Mode | Flag | Models |
| ------------------------ | -------------------- | --------------------- |
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA |
---
## Step-by-Step Guide
### Step 1: Pre-train a Base Policy
First, train a policy on your demonstration dataset:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/demo-dataset \
--policy.type=pi0 \
--output_dir=outputs/pretrain \
--batch_size=32 \
--steps=50000
```
### Step 2: Collect HIL Data
**Standard inference (ACT, Diffusion Policy):**
```bash
python examples/hil/hil_data_collection.py \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-dataset \
--dataset.single_task="Fold the T-shirt properly" \
--dataset.fps=30 \
--dataset.episode_time_s=1000 \
--dataset.num_episodes=50 \
--interpolation_multiplier=2
```
**With RTC for large models (Pi0, Pi0.5, SmolVLA):**
For models with high inference latency, enable RTC for smooth execution:
```bash
python examples/hil/hil_data_collection.py \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-rtc-dataset \
--dataset.single_task="Fold the T-shirt properly" \
--dataset.fps=30 \
--dataset.episode_time_s=1000 \
--dataset.num_episodes=50 \
--interpolation_multiplier=3
```
**Controls (Conceptual):**
The interaction model is:
- **Pause input**: pause autonomous policy execution
- **Takeover input**: transfer control to the human operator and record intervention data
- **Return-to-policy input**: hand control back to the policy and continue the same episode
- **Episode control inputs**: save/re-record/stop/reset as needed
Exact key/pedal bindings can differ across scripts and hardware integrations. Use each script's printed controls as the source of truth for the concrete mapping on your setup.
**The HIL Protocol:**
1. Watch the policy run autonomously (teleop is idle/free)
2. When you see imminent failure, trigger the **pause input**
- Policy stops
- Teleoperator moves to match robot position (torque enabled)
- No frames recorded during pause
3. Trigger the **takeover input** to take control
- Teleoperator torque disabled, free to move
- **Recovery**: Teleoperate the robot back to a good state
- **Correction**: Correct the behavior
- All movements are recorded
4. Trigger the **return-to-policy input**
- Policy resumes autonomous execution from the current state
- You can intervene again at any time (repeat steps 24)
5. End and save the episode when the task is complete (or episode time limit is reached)
6. **Reset**: Teleop moves to robot position, you can move the robot to the starting position
7. Start the next episode
**Foot Pedal Setup (Linux):**
If using a USB foot pedal (PCsensor FootSwitch), ensure access:
```bash
sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd
```
### Step 3: Fine-tune the Policy
Fine-tune on the **combined** dataset (`demo-dataset` + `hil-dataset` merged together):
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/hil-dataset \
--policy.type=pi0 \
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
--output_dir=outputs/hil_finetune \
--steps=20000
```
Then deploy the fine-tuned policy and repeat from Step 2 to target its remaining failure modes.
---
## Tips for Effective HIL Collection
### When to Intervene
Intervene when you see:
- Robot about to make an irreversible mistake
- Robot hesitating or showing uncertain behavior
- Robot deviating from the expected trajectory
### Recovery: Teleoperating Back to a Good State
During recovery, teleoperate the robot back to a state where:
- The robot is in a familiar, in-distribution configuration
- The current subtask can still be completed
- The recovery trajectory itself is informative training data
### Quality of Corrections
During correction:
- Provide **confident, clean** trajectories
- Complete the current subtask fully
- Don't overcorrect or add unnecessary movements
---
## Related Work
This HIL data collection approach builds on ideas from interactive imitation learning:
- **DAgger** (Ross et al., 2011) introduced the core idea: instead of only training on expert demonstrations, query the expert for corrections on states the _learner_ visits. This breaks the compounding-error cycle of standard behavioral cloning by iteratively collecting on-policy data.
- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`.
- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.
```bibtex
@article{ross2011dagger,
title={A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning},
author={Ross, Stéphane and Gordon, Geoffrey and Bagnell, Drew},
journal={Proceedings of the Fourteenth International Conference on Artificial Intelligence and Statistics},
year={2011}
}
@article{kelly2019hgdagger,
title={HG-DAgger: Interactive Imitation Learning with Human Experts},
author={Kelly, Michael and Sidrane, Chelsea and Driggs-Campbell, Katherine and Kochenderfer, Mykel J},
journal={arXiv preprint arXiv:1810.02890},
year={2019}
}
@article{hu2025rac,
title={RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction},
author={Hu, Zheyuan and Wu, Robyn and Enock, Naveen and Li, Jasmine and Kadakia, Riya and Erickson, Zackory and Kumar, Aviral},
journal={arXiv preprint arXiv:2509.07953},
year={2025}
}
@article{pi2025recap,
title={π0.6: a VLA That Learns From Experience},
author={Physical Intelligence},
year={2025}
}
```

View File

@@ -424,7 +424,7 @@ robot = SO100Follower(robot_config)
robot.connect()
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[episode_idx])
actions = dataset.hf_dataset.select_columns("action")
actions = dataset.select_columns("action")
log_say(f"Replaying episode {episode_idx}")
for idx in range(dataset.num_frames):

View File

@@ -1,36 +1,61 @@
# LIBERO
**LIBERO** is a benchmark designed to study **lifelong robot learning**. The idea is that robots wont just be pretrained once in a factory, theyll need to keep learning and adapting with their human users over time. This ongoing adaptation is called **lifelong learning in decision making (LLDM)**, and its a key step toward building robots that become truly personalized helpers.
LIBERO is a benchmark designed to study **lifelong robot learning** — the idea that robots need to keep learning and adapting with their users over time, not just be pretrained once. It provides a set of standardized manipulation tasks that focus on **knowledge transfer**: how well a robot can apply what it has already learned to new situations. By evaluating on LIBERO, different algorithms can be compared fairly and researchers can build on each other's work.
- 📄 [LIBERO paper](https://arxiv.org/abs/2306.03310)
- 💻 [Original LIBERO repo](https://github.com/Lifelong-Robot-Learning/LIBERO)
To make progress on this challenge, LIBERO provides a set of standardized tasks that focus on **knowledge transfer**: how well a robot can apply what it has already learned to new situations. By evaluating on LIBERO, different algorithms can be compared fairly and researchers can build on each others work.
LIBERO includes **five task suites**:
- **LIBERO-Spatial (`libero_spatial`)** tasks that require reasoning about spatial relations.
- **LIBERO-Object (`libero_object`)** tasks centered on manipulating different objects.
- **LIBERO-Goal (`libero_goal`)** goal-conditioned tasks where the robot must adapt to changing targets.
- **LIBERO-90 (`libero_90`)** 90 short-horizon tasks from the LIBERO-100 collection.
- **LIBERO-Long (`libero_10`)** 10 long-horizon tasks from the LIBERO-100 collection.
Together, these suites cover **130 tasks**, ranging from simple object manipulations to complex multi-step scenarios. LIBERO is meant to grow over time, and to serve as a shared benchmark where the community can test and improve lifelong learning algorithms.
- Paper: [Benchmarking Knowledge Transfer for Lifelong Robot Learning](https://arxiv.org/abs/2306.03310)
- GitHub: [Lifelong-Robot-Learning/LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO)
- Project website: [libero-project.github.io](https://libero-project.github.io)
![An overview of the LIBERO benchmark](https://libero-project.github.io/assets/img/libero/fig1.png)
## Evaluating with LIBERO
## Available tasks
At **LeRobot**, we ported [LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO) into our framework and used it mainly to **evaluate [SmolVLA](https://huggingface.co/docs/lerobot/en/smolvla)**, our lightweight Vision-Language-Action model.
LIBERO includes **five task suites** covering **130 tasks**, ranging from simple object manipulations to complex multi-step scenarios:
LIBERO is now part of our **multi-eval supported simulation**, meaning you can benchmark your policies either on a **single suite of tasks** or across **multiple suites at once** with just a flag.
| Suite | CLI name | Tasks | Description |
| -------------- | ---------------- | ----- | -------------------------------------------------- |
| LIBERO-Spatial | `libero_spatial` | 10 | Tasks requiring reasoning about spatial relations |
| LIBERO-Object | `libero_object` | 10 | Tasks centered on manipulating different objects |
| LIBERO-Goal | `libero_goal` | 10 | Goal-conditioned tasks with changing targets |
| LIBERO-90 | `libero_90` | 90 | Short-horizon tasks from the LIBERO-100 collection |
| LIBERO-Long | `libero_10` | 10 | Long-horizon tasks from the LIBERO-100 collection |
To Install LIBERO, after following LeRobot official instructions, just do:
`pip install -e ".[libero]"`
## Installation
After following the LeRobot installation instructions:
```bash
pip install -e ".[libero]"
```
<Tip>
LIBERO requires Linux (`sys_platform == 'linux'`). LeRobot uses MuJoCo for simulation — set the rendering backend before training or evaluation:
```bash
export MUJOCO_GL=egl # for headless servers (HPC, cloud)
```
</Tip>
## Evaluation
### Default evaluation (recommended)
Evaluate across the four standard suites (10 episodes per task):
```bash
lerobot-eval \
--policy.path="your-policy-id" \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.batch_size=1 \
--eval.n_episodes=10 \
--env.max_parallel_tasks=1
```
### Single-suite evaluation
Evaluate a policy on one LIBERO suite:
Evaluate on one LIBERO suite:
```bash
lerobot-eval \
@@ -42,15 +67,13 @@ lerobot-eval \
```
- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
- `--env.task_ids` picks task ids to run (`[0]`, `[1,2,3]`, etc.). Omit this flag (or set it to `null`) to run all tasks in the suite.
- `--env.task_ids` restricts to specific task indices (`[0]`, `[1,2,3]`, etc.). Omit to run all tasks in the suite.
- `--eval.batch_size` controls how many environments run in parallel.
- `--eval.n_episodes` sets how many episodes to run in total.
---
- `--eval.n_episodes` sets how many episodes to run per task.
### Multi-suite evaluation
Benchmark a policy across multiple suites at once:
Benchmark a policy across multiple suites at once by passing a comma-separated list:
```bash
lerobot-eval \
@@ -61,50 +84,49 @@ lerobot-eval \
--eval.n_episodes=2
```
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
### Control mode
### Control Mode
LIBERO supports two control modes — `relative` (default) and `absolute`. Different VLA checkpoints are trained with different action parameterizations, so make sure the mode matches your policy:
LIBERO now supports two control modes: relative and absolute. This matters because different VLA checkpoints are trained with different mode of action to output hence control parameterizations.
You can switch them with: `env.control_mode = "relative"` and `env.control_mode = "absolute"`
```bash
--env.control_mode=relative # or "absolute"
```
### Policy inputs and outputs
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
**Observations:**
- **Observations**
- `observation.state` proprioceptive features (agent state).
- `observation.images.image` main camera view (`agentview_image`).
- `observation.images.image2` wrist camera view (`robot0_eye_in_hand_image`).
- `observation.state` — 8-dim proprioceptive features (eef position, axis-angle orientation, gripper qpos)
- `observation.images.image` — main camera view (`agentview_image`), HWC uint8
- `observation.images.image2` — wrist camera view (`robot0_eye_in_hand_image`), HWC uint8
⚠️ **Note:** LeRobot enforces the `.images.*` prefix for any multi-modal visual features. Always ensure that your policy config `input_features` use the same naming keys, and that your dataset metadata keys follow this convention during evaluation.
If your data contains different keys, you must rename the observations to match what the policy expects, since naming keys are encoded inside the normalization statistics layer.
This will be fixed with the upcoming Pipeline PR.
<Tip warning={true}>
LeRobot enforces the `.images.*` prefix for visual features. Ensure your
policy config `input_features` use the same naming keys, and that your dataset
metadata keys follow this convention. If your data contains different keys,
you must rename the observations to match what the policy expects, since
naming keys are encoded inside the normalization statistics layer.
</Tip>
- **Actions**
- Continuous control values in a `Box(-1, 1, shape=(7,))` space.
**Actions:**
We also provide a notebook for quick testing:
Training with LIBERO
- Continuous control in `Box(-1, 1, shape=(7,))` — 6D end-effector delta + 1D gripper
## Training with LIBERO
### Recommended evaluation episodes
When training on LIBERO tasks, make sure your dataset parquet and metadata keys follow the LeRobot convention.
For reproducible benchmarking, use **10 episodes per task** across all four standard suites (Spatial, Object, Goal, Long). This gives 400 total episodes and matches the protocol used for published results.
The environment expects:
## Training
- `observation.state` → 8-dim agent state
- `observation.images.image` → main camera (`agentview_image`)
- `observation.images.image2` → wrist camera (`robot0_eye_in_hand_image`)
### Dataset
⚠️ Cleaning the dataset upfront is **cleaner and more efficient** than remapping keys inside the code.
To avoid potential mismatches and key errors, we provide a **preprocessed LIBERO dataset** that is fully compatible with the current LeRobot codebase and requires no additional manipulation:
👉 [HuggingFaceVLA/libero](https://huggingface.co/datasets/HuggingFaceVLA/libero)
We provide a preprocessed LIBERO dataset fully compatible with LeRobot:
For reference, here is the **original dataset** published by Physical Intelligence:
👉 [physical-intelligence/libero](https://huggingface.co/datasets/physical-intelligence/libero)
- [HuggingFaceVLA/libero](https://huggingface.co/datasets/HuggingFaceVLA/libero)
---
For reference, the original dataset published by Physical Intelligence:
- [physical-intelligence/libero](https://huggingface.co/datasets/physical-intelligence/libero)
### Example training command
@@ -121,52 +143,39 @@ lerobot-train \
--batch_size=4 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval_freq=1000 \
--eval_freq=1000
```
---
## Reproducing published results
### Note on rendering
We reproduce the results of Pi0.5 on the LIBERO benchmark. We take the Physical Intelligence LIBERO base model (`pi05_libero`) and finetune for an additional 6k steps in bfloat16, with batch size of 256 on 8 H100 GPUs using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation:
The finetuned model: [lerobot/pi05_libero_finetuned](https://huggingface.co/lerobot/pi05_libero_finetuned)
- `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud)
## Reproducing π₀.₅ results
We reproduce the results of π₀.₅ on the LIBERO benchmark using the LeRobot implementation. We take the Physical Intelligence LIBERO base model (`pi05_libero`) and finetune for an additional 6k steps in bfloat16, with batch size of 256 on 8 H100 GPUs using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
The finetuned model can be found here:
- **π₀.₅ LIBERO**: [lerobot/pi05_libero_finetuned](https://huggingface.co/lerobot/pi05_libero_finetuned)
We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command:
### Evaluation command
```bash
lerobot-eval \
--output_dir=/logs/ \
--output_dir=./eval_logs/ \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.batch_size=1 \
--eval.n_episodes=10 \
--policy.path=pi05_libero_finetuned \
--policy.n_action_steps=10 \
--output_dir=./eval_logs/ \
--env.max_parallel_tasks=1
```
**Note:** We set `n_action_steps=10`, similar to the original OpenPI implementation.
We set `n_action_steps=10`, matching the original OpenPI implementation.
### Results
We obtain the following results on the LIBERO benchmark:
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| ------------------- | -------------- | ------------- | ----------- | --------- | -------- |
| **Pi0.5 (LeRobot)** | 97.0 | 99.0 | 98.0 | 96.0 | **97.5** |
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| -------- | -------------- | ------------- | ----------- | --------- | -------- |
| **π₀.₅** | 97.0 | 99.0 | 98.0 | 96.0 | **97.5** |
These results are consistent with the [original results](https://github.com/Physical-Intelligence/openpi/tree/main/examples/libero#results) reported by Physical Intelligence:
These results are consistent with the original [results](https://github.com/Physical-Intelligence/openpi/tree/main/examples/libero#results) reported by Physical Intelligence:
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| -------- | -------------- | ------------- | ----------- | --------- | --------- |
| **π₀.₅** | 98.8 | 98.2 | 98.0 | 92.4 | **96.85** |
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| ------------------ | -------------- | ------------- | ----------- | --------- | --------- |
| **Pi0.5 (OpenPI)** | 98.8 | 98.2 | 98.0 | 92.4 | **96.85** |

View File

@@ -1,32 +1,111 @@
# Meta-World
Meta-World is a well-designed, open-source simulation benchmark for multi-task and meta reinforcement learning in continuous-control robotic manipulation. It gives researchers a shared, realistic playground to test whether algorithms can _learn many different tasks_ and _generalize quickly to new ones_ — two central challenges for real-world robotics.
Meta-World is an open-source simulation benchmark for **multi-task and meta reinforcement learning** in continuous-control robotic manipulation. It bundles 50 diverse manipulation tasks using everyday objects and a common tabletop Sawyer arm, providing a standardized playground to test whether algorithms can learn many different tasks and generalize quickly to new ones.
- 📄 [MetaWorld paper](https://arxiv.org/pdf/1910.10897)
- 💻 [Original MetaWorld repo](https://github.com/Farama-Foundation/Metaworld)
- Paper: [Meta-World: A Benchmark and Evaluation for Multi-Task and Meta Reinforcement Learning](https://arxiv.org/abs/1910.10897)
- GitHub: [Farama-Foundation/Metaworld](https://github.com/Farama-Foundation/Metaworld)
- Project website: [metaworld.farama.org](https://metaworld.farama.org)
![MetaWorld MT10 demo](https://meta-world.github.io/figures/ml45.gif)
## Why Meta-World matters
## Available tasks
- **Diverse, realistic tasks.** Meta-World bundles a large suite of simulated manipulation tasks (50 in the MT50 suite) using everyday objects and a common tabletop Sawyer arm. This diversity exposes algorithms to a wide variety of dynamics, contacts and goal specifications while keeping a consistent control and observation structure.
- **Focus on generalization and multi-task learning.** By evaluating across task distributions that share structure but differ in goals and objects, Meta-World reveals whether an agent truly learns transferable skills rather than overfitting to a narrow task.
- **Standardized evaluation protocol.** It provides clear evaluation modes and difficulty splits, so different methods can be compared fairly across easy, medium, hard and very-hard regimes.
- **Empirical insight.** Past evaluations on Meta-World show impressive progress on some fronts, but also highlight that current multi-task and meta-RL methods still struggle with large, diverse task sets. That gap points to important research directions.
Meta-World provides 50 tasks organized into difficulty groups. In LeRobot, you can evaluate on individual tasks, difficulty groups, or the full MT50 suite:
## What it enables in LeRobot
| Group | CLI name | Tasks | Description |
| ---------- | -------------------- | ----- | ------------------------------------------------------ |
| Easy | `easy` | 28 | Tasks with simple dynamics and single-step goals |
| Medium | `medium` | 11 | Tasks requiring multi-step reasoning |
| Hard | `hard` | 6 | Tasks with complex contacts and precise manipulation |
| Very Hard | `very_hard` | 5 | The most challenging tasks in the suite |
| MT50 (all) | Comma-separated list | 50 | All 50 tasks — the most challenging multi-task setting |
In LeRobot, you can evaluate any policy or vision-language-action (VLA) model on Meta-World tasks and get a clear success-rate measure. The integration is designed to be straightforward:
You can also pass individual task names directly (e.g., `assembly-v3`, `dial-turn-v3`).
- We provide a LeRobot-ready dataset for Meta-World (MT50) on the HF Hub: `https://huggingface.co/datasets/lerobot/metaworld_mt50`.
- This dataset is formatted for the MT50 evaluation that uses all 50 tasks (the most challenging multi-task setting).
- MT50 gives the policy a one-hot task vector and uses fixed object/goal positions for consistency.
We provide a LeRobot-ready dataset for Meta-World MT50 on the HF Hub: [lerobot/metaworld_mt50](https://huggingface.co/datasets/lerobot/metaworld_mt50). This dataset is formatted for the MT50 evaluation that uses all 50 tasks with fixed object/goal positions and one-hot task vectors for consistency.
- Task descriptions and the exact keys required for evaluation are available in the repo/dataset — use these to ensure your policy outputs the right success signals.
## Installation
## Quick start, train a SmolVLA policy on Meta-World
After following the LeRobot installation instructions:
Example command to train a SmolVLA policy on a subset of tasks:
```bash
pip install -e ".[metaworld]"
```
<Tip warning={true}>
If you encounter an `AssertionError: ['human', 'rgb_array', 'depth_array']` when running Meta-World environments, this is a mismatch between Meta-World and your Gymnasium version. Fix it with:
```bash
pip install "gymnasium==1.1.0"
```
</Tip>
## Evaluation
### Default evaluation (recommended)
Evaluate on the medium difficulty split (a good balance of coverage and compute):
```bash
lerobot-eval \
--policy.path="your-policy-id" \
--env.type=metaworld \
--env.task=medium \
--eval.batch_size=1 \
--eval.n_episodes=10
```
### Single-task evaluation
Evaluate on a specific task:
```bash
lerobot-eval \
--policy.path="your-policy-id" \
--env.type=metaworld \
--env.task=assembly-v3 \
--eval.batch_size=1 \
--eval.n_episodes=10
```
### Multi-task evaluation
Evaluate across multiple tasks or difficulty groups:
```bash
lerobot-eval \
--policy.path="your-policy-id" \
--env.type=metaworld \
--env.task=assembly-v3,dial-turn-v3,handle-press-side-v3 \
--eval.batch_size=1 \
--eval.n_episodes=10
```
- `--env.task` accepts explicit task lists (comma-separated) or difficulty groups (e.g., `easy`, `medium`, `hard`, `very_hard`).
- `--eval.batch_size` controls how many environments run in parallel.
- `--eval.n_episodes` sets how many episodes to run per task.
### Policy inputs and outputs
**Observations:**
- `observation.image` — single camera view (`corner2`), 480x480 HWC uint8
- `observation.state` — 4-dim proprioceptive state (end-effector position + gripper)
**Actions:**
- Continuous control in `Box(-1, 1, shape=(4,))` — 3D end-effector delta + 1D gripper
### Recommended evaluation episodes
For reproducible benchmarking, use **10 episodes per task**. For the full MT50 suite this gives 500 total episodes. If you care about generalization, run on the full MT50 — it is intentionally challenging and reveals strengths/weaknesses better than a few narrow tasks.
## Training
### Example training command
Train a SmolVLA policy on a subset of Meta-World tasks:
```bash
lerobot-train \
@@ -44,37 +123,8 @@ lerobot-train \
--eval_freq=1000
```
Notes:
- `--env.task` accepts explicit task lists (comma separated) or difficulty groups (e.g., `env.task="hard"`).
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
- **Gymnasium Assertion Error**: if you encounter an error like
`AssertionError: ['human', 'rgb_array', 'depth_array']` when running MetaWorld environments, this comes from a mismatch between MetaWorld and your Gymnasium version.
We recommend using:
```bash
pip install "gymnasium==1.1.0"
```
to ensure proper compatibility.
## Quick start — evaluate a trained policy
To evaluate a trained policy on the Meta-World medium difficulty split:
```bash
lerobot-eval \
--policy.path="your-policy-id" \
--env.type=metaworld \
--env.task=medium \
--eval.batch_size=1 \
--eval.n_episodes=2
```
This will run episodes and return per-task success rates using the standard Meta-World evaluation keys.
## Practical tips
- If you care about generalization, run on the full MT50 suite — its intentionally challenging and reveals strengths/weaknesses better than a few narrow tasks.
- Use the one-hot task conditioning for multi-task training (MT10 / MT50 conventions) so policies have explicit task context.
- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context.
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.

View File

@@ -0,0 +1,340 @@
# Multitask DiT Policy
Multitask Diffusion Transformer (DiT) Policy is an evolution of the original Diffusion Policy architecture, which leverages a large DiT with text and vision conditioning for multitask robot learning. This implementation supports both diffusion and flow matching objectives for action generation, enabling robots to perform diverse manipulation tasks conditioned on language instructions.
## Model Overview
The model uses:
- **CLIP Vision Encoder**: Processes RGB images from multiple camera views
- **CLIP Text Encoder**: Encodes language task instructions (frozen weights with learnable projection)
- **Diffusion Transformer**: Predicts action sequences conditioned on observations and language
- **Two Objectives**: Supports both diffusion (DDPM/DDIM) and flow matching for action generation
This model is exciting because you can achieve extremely high dexterity, competitive with multi-billion parameter
VLAs, with only ~450M parameters and significantly less training.
## Installation Requirements
Multitask DiT Policy has additional dependencies. Install it with:
```bash
pip install lerobot[multi_task_dit]
```
This will install all necessary dependencies including the HuggingFace Transformers library for CLIP models.
## Usage
To use Multitask DiT in your LeRobot configuration, specify the policy type as:
```python
policy.type=multi_task_dit
```
## Training
### Basic Training Command
Here's a complete training command for training Multitask DiT on your dataset:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/multitask_dit_training \
--batch_size=32 \
--steps=5000 \
--save_freq=500 \
--log_freq=100 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```
### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency)
For reliable performance, start with these suggested default hyperparameters:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/mutitask_dit_training \
--batch_size=320 \
--steps=30000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_train_timesteps=100 \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```
**Key Parameters:**
- **Batch Size**: 192-320 - If you have access to a GPU that can support this, you will get the best training dynamics
- **Horizon**: 32 - number of action steps to predict, ~1.0 sec at 30Hz
- **n_action_steps**: 24 - ~0.8 seconds at 30Hz
- **Objective**: `diffusion` - start with diffusion and experiment with flow matching if generation quality is poor
- **Training Steps**: >30k steps recommended for a single task
### Training Configuration Parameters
#### Objective Selection
Choose between diffusion and flow matching:
```bash
# Diffusion objective (default)
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \ # or "DDIM"
--policy.num_train_timesteps=100 \
--policy.num_inference_steps=10 \ # For faster inference
--policy.beta_schedule=squaredcos_cap_v2 \ # Noise schedule type
--policy.prediction_type=epsilon \ # "epsilon" (predict noise) or "sample" (predict clean)
--policy.clip_sample=true \ # Clip samples during denoising
--policy.clip_sample_range=1.0 # Clipping range [-x, x]
# Flow matching objective
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \ # or "uniform" | the beta sampling strategy performance appears much better in practice
--policy.num_integration_steps=100 \
--policy.integration_method=euler \ # or "rk4"
--policy.sigma_min=0.0 # Minimum noise in flow interpolation path
```
#### Transformer Architecture
Adjust model capacity based on dataset size:
```bash
# Small datasets (< 100 examples)
--policy.num_layers=4 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
# Medium datasets (100-5k examples) - default
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
# Large datasets (> 5k examples)
--policy.num_layers=8 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
```
**Positional Encoding Options:**
The model supports two positional encoding methods for action sequences:
```bash
# Rotary Position Embedding (RoPE) - default, recommended
--policy.use_rope=true \
--policy.rope_base=10000.0 # Base frequency for RoPE
# Absolute positional encoding
--policy.use_positional_encoding=true # Disables RoPE when true
```
**Other Transformer Parameters:**
```bash
--policy.dropout=0.1 # Dropout rate for DiT blocks (0.0-1.0)
--policy.timestep_embed_dim=256 # Timestep embedding dimension
```
#### Vision Encoder Configuration
```bash
# Use different CLIP model for more expressivity at the cost of inference time
# experiment with larger or smaller models depending on the complexity of your tasks and size of dataset
--policy.vision_encoder_name=openai/clip-vit-large-patch14
# Use separate vision encoder per camera
# This may be useful when cameras have significantly different characteristics, but
# be wary of increased VRAM footprint.
--policy.use_separate_rgb_encoder_per_camera=true
# Image preprocessing
--policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed ups
--policy.image_crop_shape=[224,224] \
--policy.image_crop_is_random=true # Random during training, center at inference
```
#### Text Encoder Configuration
```bash
# Use different CLIP text encoder model
# same as vision: experiment with larger or smaller models depending on the
# complexity of your tasks and size of dataset
--policy.text_encoder_name=openai/clip-vit-large-patch14
```
#### Learning Rate Configuration
The vision encoder uses a separate learning rate multiplier, where 1/10th is suggested to be the ideal staritng point:
```bash
--policy.optimizer_lr=2e-5 \
--policy.vision_encoder_lr_multiplier=0.1 # Vision encoder LR = 0.1 * optimizer_lr
```
### Training Tuning Guidelines
#### 1. Flow Matching with Beta Sampling
The original diffusion implementation here is based on the work described in [TRI's LBM paper](https://arxiv.org/abs/2507.05331)
Additionally, we have implemented a flow-matching objective, which is described at a high-level in [Boston Dynamics blog post](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/).
Consider testing the flow-matching objective and evaluating performance differences for your task:
```bash
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \
--policy.timestep_sampling_alpha=1.5 \
--policy.timestep_sampling_beta=1.0 \
--policy.timestep_sampling_s=0.999
```
This hasn't been shown to be a silver bullet across every user case, but it occasionally results in smoother and more consistent actions.
#### 2. Number of Transformer Layers
Match model capacity to your dataset size:
- **Small datasets** (< 100 examples): Reduce to 4 layers
- **Large datasets** (> 5k examples): Increase to 8 layers
#### 3. `horizon` Tuning
The model can be sensitive to the horizon you choose. Start with around a 1 second horizon based on your control frequency:
- **30 Hz frequency**: `horizon=30`
- **10 Hz frequency**: `horizon=10`
Then experiment with increasing from there. The horizon determines how far into the future the model predicts actions.
#### 4. `n_action_steps` Sensitivity
The model can also be very sensitive to `n_action_steps`. Start with it being around 0.8 seconds based on your control frequency and tune from there:
- **Lower values**: More reactive but potentially less stable for long-horizon tasks
- **Higher values**: Better for long-horizon execution but open-loop failures are limited in their recovery
### Inference Tuning
For faster inference, use DDIM with fewer sampling steps:
```bash
--policy.noise_scheduler_type=DDIM \
--policy.num_inference_steps=10
```
### Resuming Training
To resume training from a checkpoint:
```bash
lerobot-train \
--config_path=./outputs/mutitask_dit_training/checkpoints/last/pretrained_model/train_config.json \
--resume=true
```
The checkpoint directory should contain `model.safetensors` and `config.json` files (saved automatically during training). When resuming, the configuration is loaded from the checkpoint, so you don't need to specify other parameters.
## Common Failure Modes and Debugging
Training these models can be finicky. Here are common failure modes and debugging approaches:
### Idling / No Motion
The model may "collapse" during inference, resulting in static or no motion. This can occur when:
1. **Insufficient training data**: If you only have 20-50 examples, try to roughly double your dataset size. Once you have above 300 examples, if you're still seeing this, the task may be too complex.
2. **Multiple similar tasks**: When your dataset contains multiple similar tasks (e.g., picking up 2 different objects), the model may rely too heavily on language conditioning which might not be rich enough.
**Debugging tips:**
- Increase dataset size (double until you get to over 300 examples)
- Train for longer, up to 100k steps, even when the loss flatlines
- Check if the model is receiving proper language instructions or increase diversity of instruction
### Executing the Wrong Task
Sometimes the robot will completely ignore your instruction and perform some other task. This generally only happens if you have trained on multiple tasks.
**Potential causes:**
- Language instruction ambiguity
- Insufficient task-specific training data
- Model confusion between similar tasks in the multitask dataset
**Debugging tips:**
- Verify language instruction specificity, especially if descriptions are similar between multiple tasks
- Check task distribution in your training dataset and add weighting to the failing/ignored task
- Consider task-specific fine-tuning
### Training Instability
If training loss is unstable or diverging:
- Try adjusting learning rate between `1e-5` and `3e-4`
- Increase batch size if possible
- Check that your dataset normalization is correct
- Verify image preprocessing is working correctly
## Performance Considerations
### GPU Requirements
- **Inference**: At least an RTX 5070 Ti (or equivalent GPU) is recommended for reasonable speed performance
- **Training**: A GPU with enough VRAM to load batch sizes of >64 is ideal, which will vary depending on the number of image observations, etc
### Batch Size Recommendations
- **Minimum**: 64 (less than this may result in unstable training)
- **Recommended**: 256-320 (best performance, requires larger GPU)
## Example: Training on Custom Dataset
Here's a complete example training on a custom dataset:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/mutitask_dit_training \
--batch_size=320 \
--steps=30000 \
--save_freq=1000 \
--log_freq=100 \
--eval_freq=1000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.vision_encoder_name=openai/clip-vit-base-patch16 \
--policy.image_resize_shape=[320,240] \
--policy.image_crop_shape=[224,224] \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true \
--wandb.project=multitask_dit
```
## References
For more details on the technical implementation and architecture, see:
- [A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation](https://arxiv.org/abs/2507.05331)
- [Large Behavior Models and Atlas Find New Footing](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/)
- [Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy](https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy)

View File

@@ -91,6 +91,46 @@ lerobot-train \
**💡 Tip**: Setting `train_expert_only=true` freezes the VLM and trains only the action expert and projections, allowing finetuning with reduced memory usage.
## Relative Actions
By default, π₀ predicts absolute actions. You can enable **relative actions** so the model predicts offsets relative to the current robot state. This can improve training stability for certain setups.
To use relative actions, first recompute your dataset stats in relative space via the CLI:
```bash
lerobot-edit-dataset \
--repo_id your_dataset \
--operation.type recompute_stats \
--operation.relative_action true \
--operation.chunk_size 50 \
--operation.relative_exclude_joints "['gripper']" \
--push_to_hub true
```
Or equivalently in Python:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.dataset_tools import recompute_stats
dataset = LeRobotDataset("your_dataset")
recompute_stats(dataset, relative_action=True, chunk_size=50, relative_exclude_joints=["gripper"])
dataset.push_to_hub()
```
The `chunk_size` should match your policy's `chunk_size` (default 50 for π₀). `relative_exclude_joints` lists joint names that should remain in absolute space (e.g. gripper commands). Use `--push_to_hub true` to upload the updated stats to the Hub.
Then train with relative actions enabled:
```bash
lerobot-train \
--dataset.repo_id=your_dataset \
--policy.type=pi0 \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]' \
...
```
## License
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).

View File

@@ -97,6 +97,46 @@ python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
Or train pi05 with this normalization mapping: `--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'`
## Relative Actions
By default, π₀.₅ predicts absolute actions. You can enable **relative actions** so the model predicts offsets relative to the current robot state. This can improve training stability for certain setups.
To use relative actions, first recompute your dataset stats in relative space via the CLI:
```bash
lerobot-edit-dataset \
--repo_id your_dataset \
--operation.type recompute_stats \
--operation.relative_action true \
--operation.chunk_size 50 \
--operation.relative_exclude_joints "['gripper']" \
--push_to_hub true
```
Or equivalently in Python:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.dataset_tools import recompute_stats
dataset = LeRobotDataset("your_dataset")
recompute_stats(dataset, relative_action=True, chunk_size=50, relative_exclude_joints=["gripper"])
dataset.push_to_hub()
```
The `chunk_size` should match your policy's `chunk_size` (default 50 for π₀.₅). `relative_exclude_joints` lists joint names that should remain in absolute space (e.g. gripper commands). Use `--push_to_hub true` to upload the updated stats to the Hub.
Then train with relative actions enabled:
```bash
lerobot-train \
--dataset.repo_id=your_dataset \
--policy.type=pi05 \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]' \
...
```
## Performance Results
### Libero Benchmark Results

View File

@@ -0,0 +1,37 @@
# Multitask DiT Policy
## Citation
If you use this work, please cite the following works:
```bibtex
@misc{jones2025multitaskditpolicy,
author = {Bryson Jones},
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
year = {2025},
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
note = {Blog post}
}
```
```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
author = {TRI LBM Team},
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
year = {2025},
eprint = {arXiv:2507.05331},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2507.05331}
}
```
```bibtex
@misc{bostondynamics2025largebehaviormodelsatlas,
author = {Boston Dynamics and TRI Research Team},
title = {Large Behavior Models and Atlas Find New Footing},
year = {2025},
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
note = {Blog post}
}
```

114
docs/source/rename_map.mdx Normal file
View File

@@ -0,0 +1,114 @@
# Rename Map and Empty Cameras
When you train, evaluate, or record with a robot policy, your **dataset** or **environment** provides observations under one set of keys (e.g. `observation.images.front`, `observation.images.eagle`), while your **policy** expects another (e.g. `observation.images.image`, `observation.images.image2`). The **rename map** bridges that gap without changing the policy or data source.
> **Scope:** The rename map only renames **observation** keys (images and state). Action keys are not affected.
## Why observation keys don't always match
Policies have a fixed set of **input feature names** baked into their pretrained config. For example:
- [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero) expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb`.
- [xvla-base](https://huggingface.co/lerobot/xvla-base) expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`.
Your dataset might use different names entirely (e.g. `observation.images.front`, `observation.images.eagle`, `observation.images.glove`), and your eval environment might use yet another set. Rather than editing the policy config or renaming columns in the dataset, you pass a **rename map**: a JSON dictionary that maps source keys to the keys the policy expects. Renaming happens inside the preprocessor pipeline, so the policy always sees its expected keys.
## Using the rename map
Pass the mapping as a JSON string on the command line. The convention is always:
```
--rename_map='{"source_key": "policy_key", ...}'
```
where **source_key** is what the dataset or environment provides, and **policy_key** is what the policy expects.
Only listed keys are renamed; everything else passes through unchanged. Order of entries doesn't matter.
Supported policies: **PI0**, **PI05**, **PI0Fast**, **SmolVLA**, and **XVLA**.
### Training
Suppose you fine-tune [lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base) on a dataset with images under `observation.images.front`, `observation.images.eagle`, and `observation.images.glove`. XVLA expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/xvla_training \
--job_name=xvla_training \
--policy.path="lerobot/xvla-base" \
--policy.repo_id="HF_USER/xvla-your-robot" \
--policy.dtype=bfloat16 \
--policy.action_mode=auto \
--steps=20000 \
--policy.device=cuda \
--policy.freeze_vision_encoder=false \
--policy.freeze_language_encoder=false \
--policy.train_policy_transformer=true \
--policy.train_soft_prompts=true \
--rename_map='{"observation.images.front": "observation.images.image", "observation.images.eagle": "observation.images.image2", "observation.images.glove": "observation.images.image3"}'
```
### Evaluation
A policy that expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb` (e.g. [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)), but the LIBERO environment returns `observation.images.image` and `observation.images.image2`:
```bash
lerobot-eval \
--policy.path=lerobot/pi0fast-libero \
--env.type=libero \
... \
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
```
### Recording
`lerobot-record` also supports rename maps, nested under the dataset config:
```bash
lerobot-record \ # When running inference
--policy.path="<user>/smolVLA_finetuned" \
... \
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
```
## Alternative: edit the policy config directly
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
The tradeoff: modifying the policy config ties it to one data source. A rename map keeps one policy usable across many datasets and environments.
## Empty cameras: fewer views than the policy expects
Some policies are built for a fixed number of image inputs. If your dataset has fewer cameras, you can set **`empty_cameras`** in the policy config instead of modifying the model architecture.
### How it works
Setting `empty_cameras=N` adds N placeholder image features to the policy config, named:
```
observation.images.empty_camera_0
observation.images.empty_camera_1
...
```
At runtime, these keys have no corresponding data in the batch. The policy fills them with masked dummy tensors (padded with `-1` for SigLIP-based vision encoders, with a zero attention mask), so the extra image slots are effectively ignored during training and inference.
### Example
XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset only has two cameras:
1. Set `--policy.empty_cameras=1`.
2. The config adds a third key: `observation.images.empty_camera_0`.
3. Use the rename map for your two real cameras as usual.
4. The third slot is masked out — no fake images needed in your dataset.
## Quick reference
| Goal | What to do |
| ----------------------------------------- | --------------------------------------------------------------------------- |
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. |
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |

View File

@@ -236,10 +236,10 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
### Joint 1
- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn.
- Place the first motor into the base.
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
- Install both motor horns, securing the top horn with a M3x6mm screw.
- Attach the shoulder part.
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
- Add the shoulder motor holder.
@@ -255,9 +255,9 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
### Joint 2
- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn.
- Slide the second motor in from the top.
- Fasten the second motor with 4 M2x6mm screws.
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
- Attach the upper arm with 4 M3x6mm screws on each side.
<div class="video-container">
@@ -271,8 +271,8 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
### Joint 3
- Insert motor 3 and fasten using 4 M2x6mm screws
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn.
- Insert motor 3 and fasten using 4 M2x6mm screws.
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
<div class="video-container">
@@ -286,9 +286,10 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
### Joint 4
- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn.
- Slide over motor holder 4.
- Slide in motor 4.
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
- Fasten motor 4 with 4 M2x6mm screws.
<div class="video-container">
<video controls width="600">
@@ -321,7 +322,7 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
- Attach the motor horns and again use a M3x6mm horn screw.
- Install both motor horns on the gripper motor. Secure the top horn with a M3x6mm screw; no screws are required for the bottom horn.
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
<div class="video-container">

View File

@@ -78,7 +78,7 @@ def replay(cfg: ReplayConfig):
robot = make_robot_from_config(cfg.robot)
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
actions = dataset.hf_dataset.select_columns(ACTION)
actions = dataset.select_columns(ACTION)
robot.connect()
try:

View File

@@ -0,0 +1,680 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
Downloads datasets from HuggingFace, seeks directly into the episode segment
of the source video, draws a progress line on each frame, and writes the result.
Usage:
python examples/dataset/create_progress_videos.py \
--repo-id lerobot-data-collection/level2_final_quality3 \
--episode 1100
python examples/dataset/create_progress_videos.py \
--repo-id lerobot-data-collection/level2_final_quality3 \
--episode 1100 \
--camera-key observation.images.top \
--output-dir ./my_videos \
--gif
"""
from __future__ import annotations
import argparse
import json
import logging
import subprocess
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
from huggingface_hub import snapshot_download
GRAPH_Y_TOP_FRAC = 0.01
GRAPH_Y_BOT_FRAC = 0.99
LINE_THICKNESS = 3
SHADOW_THICKNESS = 6
REF_ALPHA = 0.45
FILL_ALPHA = 0.55
SCORE_FONT_SCALE = 0.8
TASK_FONT_SCALE = 0.55
def download_episode_metadata(repo_id: str, episode: int) -> Path:
"""Download only the metadata and sarm_progress files for a dataset.
Args:
repo_id: HuggingFace dataset repository ID.
episode: Episode index (used for logging only; all meta is fetched).
Returns:
Local cache path for the downloaded snapshot.
"""
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
local_path = Path(
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
allow_patterns=["meta/**", "sarm_progress.parquet"],
ignore_patterns=["*.mp4"],
)
)
return local_path
def load_episode_meta(local_path: Path, episode: int, camera_key: str | None) -> dict:
"""Read info.json and episode parquet to resolve fps, video path, and timestamps.
Args:
local_path: Local cache directory containing meta/.
episode: Episode index to look up.
camera_key: Camera observation key (e.g. "observation.images.base").
If None, the first available video key is used.
Returns:
Dict with keys: fps, camera, video_rel, chunk_index, file_index,
from_ts, to_ts, task_name.
"""
info = json.loads((local_path / "meta" / "info.json").read_text())
fps = info["fps"]
features = info["features"]
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
if not video_keys:
raise RuntimeError("No video keys found in dataset features")
if camera_key is not None:
if camera_key not in video_keys:
raise RuntimeError(f"camera_key='{camera_key}' not found. Available: {video_keys}")
selected_camera = camera_key
else:
selected_camera = video_keys[0]
logging.info(" fps=%d camera='%s' all_cams=%s", fps, selected_camera, video_keys)
episode_rows = []
for parquet_file in sorted((local_path / "meta" / "episodes").glob("**/*.parquet")):
episode_rows.append(pd.read_parquet(parquet_file))
episode_df = pd.concat(episode_rows, ignore_index=True)
row = episode_df[episode_df["episode_index"] == episode]
if row.empty:
raise RuntimeError(f"Episode {episode} not found in episode metadata")
row = row.iloc[0]
chunk_col = f"videos/{selected_camera}/chunk_index"
file_col = f"videos/{selected_camera}/file_index"
ts_from_col = f"videos/{selected_camera}/from_timestamp"
ts_to_col = f"videos/{selected_camera}/to_timestamp"
if chunk_col not in row.index:
chunk_col = f"{selected_camera}/chunk_index"
file_col = f"{selected_camera}/file_index"
ts_from_col = f"{selected_camera}/from_timestamp"
ts_to_col = f"{selected_camera}/to_timestamp"
if chunk_col not in row.index:
raise RuntimeError(
f"Cannot find video metadata columns for {selected_camera}.\nAvailable: {list(row.index)}"
)
chunk_index = int(row[chunk_col])
file_index = int(row[file_col])
from_timestamp = float(row[ts_from_col])
to_timestamp = float(row[ts_to_col])
video_template = info.get(
"video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4"
)
video_rel = video_template.format(
video_key=selected_camera,
chunk_index=chunk_index,
file_index=file_index,
)
task_name = _resolve_task_name(row, local_path)
return {
"fps": fps,
"camera": selected_camera,
"video_rel": video_rel,
"chunk_index": chunk_index,
"file_index": file_index,
"from_ts": from_timestamp,
"to_ts": to_timestamp,
"task_name": task_name,
}
def _resolve_task_name(row: pd.Series, local_path: Path) -> str:
"""Best-effort extraction of the task name for an episode row.
Args:
row: Single-episode row from the episodes parquet.
local_path: Dataset cache root.
Returns:
Task name string, or empty string if unavailable.
"""
try:
if "tasks" in row.index and row["tasks"] is not None:
tasks_val = row["tasks"]
if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0:
return str(tasks_val[0])
return str(tasks_val).strip("[]'")
tasks_parquet = local_path / "meta" / "tasks.parquet"
if tasks_parquet.exists():
tasks_df = pd.read_parquet(tasks_parquet)
task_idx = int(row.get("task_index", 0)) if "task_index" in row.index else 0
match = tasks_df[tasks_df["task_index"] == task_idx]
if not match.empty:
return str(match.index[0])
except Exception as exc:
logging.warning("Could not load task name: %s", exc)
return ""
def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
"""Download the specific video file if not already cached.
Args:
repo_id: HuggingFace dataset repository ID.
local_path: Local cache directory.
video_rel: Relative path to the video file within the dataset.
Returns:
Absolute path to the downloaded video file.
"""
video_path = local_path / video_rel
if video_path.exists():
logging.info(" Video already cached: %s", video_path)
return video_path
logging.info("[2/4] Downloading video file %s ...", video_rel)
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
local_dir=str(local_path),
allow_patterns=[video_rel],
)
if not video_path.exists():
raise RuntimeError(f"Video not found after download: {video_path}")
return video_path
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
"""Load sarm_progress values for an episode.
Args:
local_path: Dataset cache root.
episode: Episode index.
Returns:
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
"""
parquet_path = local_path / "sarm_progress.parquet"
if not parquet_path.exists():
logging.warning("sarm_progress.parquet not found")
return None
df = pd.read_parquet(parquet_path)
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
episode_df = df[df["episode_index"] == episode].copy()
if episode_df.empty:
logging.warning("No sarm_progress rows for episode %d", episode)
return None
episode_df = episode_df.sort_values("frame_index")
if "progress_dense" in episode_df.columns and episode_df["progress_dense"].notna().any():
progress_column = "progress_dense"
elif "progress_sparse" in episode_df.columns:
progress_column = "progress_sparse"
else:
progress_columns = [c for c in episode_df.columns if "progress" in c.lower()]
if not progress_columns:
return None
progress_column = progress_columns[0]
logging.info(" Using progress column: '%s'", progress_column)
return episode_df[["frame_index", progress_column]].rename(columns={progress_column: "progress"}).values
def _precompute_pixel_coords(
progress_data: np.ndarray,
num_frames: int,
frame_width: int,
frame_height: int,
) -> np.ndarray:
"""Map progress samples to pixel coordinates for overlay drawing.
Args:
progress_data: (N, 2) array of (frame_index, progress).
num_frames: Total number of video frames.
frame_width: Video width in pixels.
frame_height: Video height in pixels.
Returns:
(N, 2) array of (x, y) pixel coordinates.
"""
frame_indices = progress_data[:, 0].astype(float)
progress_values = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0)
y_top = int(frame_height * GRAPH_Y_TOP_FRAC)
y_bot = int(frame_height * GRAPH_Y_BOT_FRAC)
graph_height = y_bot - y_top
x_coords = (frame_indices / (num_frames - 1) * (frame_width - 1)).astype(int)
y_coords = (y_bot - progress_values * graph_height).astype(int)
return np.stack([x_coords, y_coords], axis=1)
def _progress_color(normalized_position: float) -> tuple[int, int, int]:
"""Interpolate BGR color from red to green based on position in [0, 1].
Args:
normalized_position: Value in [0, 1] indicating how far along the episode.
Returns:
BGR color tuple.
"""
red = int(255 * (1.0 - normalized_position))
green = int(255 * normalized_position)
return (0, green, red)
def _prerender_fill_polygon(
pixel_coords: np.ndarray,
frame_width: int,
frame_height: int,
) -> np.ndarray:
"""Pre-render the grey fill polygon under the progress curve as a BGRA image.
Args:
pixel_coords: (N, 2) array of (x, y) pixel coordinates.
frame_width: Video width in pixels.
frame_height: Video height in pixels.
Returns:
BGRA image array of shape (frame_height, frame_width, 4).
"""
y_bot = int(frame_height * GRAPH_Y_BOT_FRAC)
fill_image = np.zeros((frame_height, frame_width, 4), dtype=np.uint8)
polygon = np.concatenate(
[
pixel_coords,
[[pixel_coords[-1][0], y_bot], [pixel_coords[0][0], y_bot]],
],
axis=0,
).astype(np.int32)
cv2.fillPoly(fill_image, [polygon], color=(128, 128, 128, int(255 * FILL_ALPHA)))
return fill_image
def _alpha_composite_region(base: np.ndarray, overlay_bgra: np.ndarray, x_limit: int) -> None:
"""Blend BGRA overlay onto BGR base in-place, up to x_limit columns.
Args:
base: BGR frame to draw on (modified in-place).
overlay_bgra: BGRA overlay image.
x_limit: Only blend columns [0, x_limit).
"""
if x_limit <= 0:
return
region_base = base[:, :x_limit]
region_overlay = overlay_bgra[:, :x_limit]
alpha = region_overlay[:, :, 3:4].astype(np.float32) / 255.0
region_base[:] = np.clip(
region_overlay[:, :, :3].astype(np.float32) * alpha + region_base.astype(np.float32) * (1.0 - alpha),
0,
255,
).astype(np.uint8)
def _draw_text_outlined(
frame: np.ndarray,
text: str,
position: tuple[int, int],
font_scale: float,
thickness: int = 1,
) -> None:
"""Draw white text with a dark outline for readability on any background.
Args:
frame: BGR image to draw on (modified in-place).
text: String to render.
position: (x, y) bottom-left corner of the text.
font_scale: OpenCV font scale.
thickness: Text stroke thickness.
"""
font = cv2.FONT_HERSHEY_SIMPLEX
cv2.putText(frame, text, position, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA)
cv2.putText(frame, text, position, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
def composite_progress_video(
video_path: Path,
from_timestamp: float,
to_timestamp: float,
progress_data: np.ndarray,
output_path: Path,
fps: float,
task_name: str = "",
) -> Path:
"""Read episode frames by seeking into the source video, draw progress overlay, write output.
Uses cv2.CAP_PROP_POS_MSEC to seek directly into the source video,
eliminating the need for an intermediate clip file.
Args:
video_path: Path to the full source video file.
from_timestamp: Start timestamp of the episode in seconds.
to_timestamp: End timestamp of the episode in seconds.
progress_data: (N, 2) array of (frame_index, progress).
output_path: Path to write the output MP4.
fps: Frames per second for the output video.
task_name: Optional task name to display at the top of the video.
Returns:
Path to the written output file (MP4).
"""
capture = cv2.VideoCapture(str(video_path))
try:
capture.set(cv2.CAP_PROP_POS_MSEC, from_timestamp * 1000)
frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
duration_seconds = to_timestamp - from_timestamp
num_frames = int(round(duration_seconds * fps))
logging.info(
" Video: %dx%d, %d frames @ %.1f fps (%.2fs)",
frame_width,
frame_height,
num_frames,
fps,
duration_seconds,
)
pixel_coords = _precompute_pixel_coords(progress_data, num_frames, frame_width, frame_height)
y_ref = int(frame_height * GRAPH_Y_TOP_FRAC)
fill_image = _prerender_fill_polygon(pixel_coords, frame_width, frame_height)
ref_line_image = np.zeros((frame_height, frame_width, 4), dtype=np.uint8)
cv2.line(
ref_line_image,
(0, y_ref),
(frame_width - 1, y_ref),
(200, 200, 200, int(255 * REF_ALPHA)),
1,
cv2.LINE_AA,
)
frame_indices = progress_data[:, 0].astype(int)
progress_values = progress_data[:, 1].astype(float)
logging.info("[3/4] Compositing %d frames ...", num_frames)
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
writer = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height))
for frame_idx in range(num_frames):
ret, frame = capture.read()
if not ret:
break
drawn_count = int(np.searchsorted(frame_indices, frame_idx, side="right"))
x_current = (
int(pixel_coords[min(drawn_count, len(pixel_coords)) - 1][0]) + 1 if drawn_count > 0 else 0
)
_alpha_composite_region(frame, ref_line_image, frame_width)
_alpha_composite_region(frame, fill_image, x_current)
if drawn_count >= 2:
time_position = (drawn_count - 1) / max(len(progress_values) - 1, 1)
line_color = _progress_color(time_position)
points = pixel_coords[:drawn_count].reshape(-1, 1, 2).astype(np.int32)
cv2.polylines(
frame,
[points],
isClosed=False,
color=(255, 255, 255),
thickness=SHADOW_THICKNESS,
lineType=cv2.LINE_AA,
)
cv2.polylines(
frame,
[points],
isClosed=False,
color=line_color,
thickness=LINE_THICKNESS,
lineType=cv2.LINE_AA,
)
if drawn_count > 0:
score = float(progress_values[min(drawn_count, len(progress_values)) - 1])
score_text = f"{score:.2f}"
(text_width, _), _ = cv2.getTextSize(
score_text, cv2.FONT_HERSHEY_SIMPLEX, SCORE_FONT_SCALE, 2
)
score_x = frame_width - text_width - 12
score_y = frame_height - 12
time_position = (drawn_count - 1) / max(len(progress_values) - 1, 1)
score_color = _progress_color(time_position)
cv2.putText(
frame,
score_text,
(score_x, score_y),
cv2.FONT_HERSHEY_SIMPLEX,
SCORE_FONT_SCALE,
(0, 0, 0),
4,
cv2.LINE_AA,
)
cv2.putText(
frame,
score_text,
(score_x, score_y),
cv2.FONT_HERSHEY_SIMPLEX,
SCORE_FONT_SCALE,
score_color,
2,
cv2.LINE_AA,
)
if task_name:
(text_width, _), _ = cv2.getTextSize(task_name, cv2.FONT_HERSHEY_SIMPLEX, TASK_FONT_SCALE, 1)
task_x = max((frame_width - text_width) // 2, 4)
_draw_text_outlined(frame, task_name, (task_x, 22), TASK_FONT_SCALE)
writer.write(frame)
if frame_idx % 100 == 0:
logging.info(" Frame %d/%d ...", frame_idx, num_frames)
writer.release()
finally:
capture.release()
logging.info(" MP4 written: %s", output_path)
return output_path
def convert_mp4_to_gif(mp4_path: Path) -> Path:
"""Convert an MP4 to an optimized GIF using ffmpeg palette generation.
Args:
mp4_path: Path to the source MP4 file.
Returns:
Path to the generated GIF file.
"""
capture = cv2.VideoCapture(str(mp4_path))
frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
capture.release()
gif_path = mp4_path.with_suffix(".gif")
palette_path = mp4_path.parent / "_palette.png"
logging.info("[4/4] Converting to GIF ...")
result_palette = subprocess.run( # nosec B607
[
"ffmpeg",
"-y",
"-i",
str(mp4_path),
"-vf",
f"fps=10,scale={frame_width}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff",
"-update",
"1",
str(palette_path),
],
capture_output=True,
text=True,
)
if result_palette.returncode != 0:
logging.warning("palettegen failed:\n%s", result_palette.stderr[-500:])
result_gif = subprocess.run( # nosec B607
[
"ffmpeg",
"-y",
"-i",
str(mp4_path),
"-i",
str(palette_path),
"-filter_complex",
f"fps=10,scale={frame_width}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3",
str(gif_path),
],
capture_output=True,
text=True,
)
if result_gif.returncode != 0:
logging.warning("GIF encode failed:\n%s", result_gif.stderr[-500:])
palette_path.unlink(missing_ok=True)
logging.info(" GIF written: %s", gif_path)
return gif_path
def process_dataset(
repo_id: str,
episode: int,
camera_key: str | None,
output_dir: Path,
create_gif: bool = False,
) -> Path | None:
"""Full pipeline: download, extract metadata, composite progress, write output.
Args:
repo_id: HuggingFace dataset repository ID.
episode: Episode index.
camera_key: Camera key to use, or None for auto-selection.
output_dir: Directory to write output files.
create_gif: If True, also generate a GIF from the MP4.
Returns:
Path to the final output file, or None on failure.
"""
safe_name = repo_id.replace("/", "_")
logging.info("Processing: %s | episode %d", repo_id, episode)
local_path = download_episode_metadata(repo_id, episode)
logging.info(" Local cache: %s", local_path)
episode_meta = load_episode_meta(local_path, episode, camera_key)
logging.info(" Episode meta: %s", episode_meta)
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
progress_data = load_progress_data(local_path, episode)
if progress_data is None:
logging.error("Could not load sarm_progress data. Skipping overlay.")
return None
logging.info(" Progress frames: %d", len(progress_data))
output_path = output_dir / f"{safe_name}_ep{episode}_progress.mp4"
final_path = composite_progress_video(
video_path=video_path,
from_timestamp=episode_meta["from_ts"],
to_timestamp=episode_meta["to_ts"],
progress_data=progress_data,
output_path=output_path,
fps=episode_meta["fps"],
task_name=episode_meta.get("task_name", ""),
)
if create_gif:
final_path = convert_mp4_to_gif(final_path)
logging.info("Done: %s", final_path)
return final_path
def main() -> None:
parser = argparse.ArgumentParser(
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="HuggingFace dataset repository ID (e.g. 'lerobot-data-collection/level2_final_quality3').",
)
parser.add_argument(
"--episode",
type=int,
required=True,
help="Episode index to visualize.",
)
parser.add_argument(
"--camera-key",
type=str,
default=None,
help="Camera observation key (e.g. 'observation.images.base'). Auto-selects first camera if omitted.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("progress_videos"),
help="Directory to write output files (default: ./progress_videos).",
)
parser.add_argument(
"--gif",
action="store_true",
help="Also generate a GIF from the MP4 output.",
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
args.output_dir.mkdir(parents=True, exist_ok=True)
result = process_dataset(
repo_id=args.repo_id,
episode=args.episode,
camera_key=args.camera_key,
output_dir=args.output_dir,
create_gif=args.gif,
)
if result:
logging.info("Output: %s", result)
if __name__ == "__main__":
main()

View File

@@ -88,9 +88,8 @@ def main():
# The previous metadata class is contained in the 'meta' attribute of the dataset:
print(dataset.meta)
# LeRobotDataset actually wraps an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets for more information).
print(dataset.hf_dataset)
# You can inspect the dataset using its repr:
print(dataset)
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset.

File diff suppressed because it is too large Load Diff

228
examples/hil/hil_utils.py Normal file
View File

@@ -0,0 +1,228 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for Human-in-the-Loop data collection scripts."""
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.processor import (
IdentityProcessorStep,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
)
from lerobot.processor.converters import (
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots import Robot
from lerobot.teleoperators import Teleoperator
from lerobot.utils.control_utils import is_headless
from lerobot.utils.robot_utils import precise_sleep
logger = logging.getLogger(__name__)
@dataclass
class HILDatasetConfig:
repo_id: str
single_task: str
root: str | Path | None = None
fps: int = 30
episode_time_s: float = 120
num_episodes: int = 50
video: bool = True
push_to_hub: bool = True
private: bool = False
tags: list[str] | None = None
num_image_writer_processes: int = 0
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
vcodec: str = "auto"
streaming_encoding: bool = True
encoder_queue_maxsize: int = 30
encoder_threads: int | None = None
rename_map: dict[str, str] = field(default_factory=dict)
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
"""Check if teleoperator has motor control capabilities."""
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
def teleop_disable_torque(teleop: Teleoperator) -> None:
"""Disable teleop torque if supported."""
if hasattr(teleop, "disable_torque"):
teleop.disable_torque()
def teleop_enable_torque(teleop: Teleoperator) -> None:
"""Enable teleop torque if supported."""
if hasattr(teleop, "enable_torque"):
teleop.enable_torque()
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
"""Smoothly move teleop to target position if motor control is available."""
if not teleop_has_motor_control(teleop):
logger.warning("Teleop does not support motor control - cannot mirror robot position")
return
teleop_enable_torque(teleop)
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {}
for k in current:
if k in target_pos:
interp[k] = current[k] * (1 - t) + target_pos[k] * t
else:
interp[k] = current[k]
teleop.write_goal_positions(interp)
time.sleep(1 / fps)
def init_keyboard_listener():
"""Initialize keyboard listener with HIL controls."""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"policy_paused": False,
"correction_active": False,
"resume_policy": False,
"in_reset": False,
"start_next_episode": False,
}
if is_headless():
logger.warning("Headless environment - keyboard controls unavailable")
return None, events
from pynput import keyboard
def on_press(key):
try:
if events["in_reset"]:
if key in [keyboard.Key.space, keyboard.Key.right]:
logger.info("[HIL] Starting next episode...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "c":
events["start_next_episode"] = True
elif key == keyboard.Key.esc:
logger.info("[HIL] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["start_next_episode"] = True
else:
if key == keyboard.Key.space:
if not events["policy_paused"] and not events["correction_active"]:
logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy")
events["policy_paused"] = True
elif hasattr(key, "char") and key.char == "c":
if events["policy_paused"] and not events["correction_active"]:
logger.info("[HIL] Taking control...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "p":
if events["policy_paused"] or events["correction_active"]:
logger.info("[HIL] Resuming policy...")
events["resume_policy"] = True
elif key == keyboard.Key.right:
logger.info("[HIL] End episode")
events["exit_early"] = True
elif key == keyboard.Key.left:
logger.info("[HIL] Re-record episode")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
logger.info("[HIL] ESC - Stop recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
logger.info(f"Key error: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def make_identity_processors():
"""Create identity processors for recording."""
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
return teleop_proc, obs_proc
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
"""Reset period where human repositions environment."""
logger.info("[HIL] RESET")
events["in_reset"] = True
events["start_next_episode"] = False
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
logger.info("Press any key to enable teleoperation")
while not events["start_next_episode"] and not events["stop_recording"]:
precise_sleep(0.05)
if events["stop_recording"]:
return
events["start_next_episode"] = False
teleop_disable_torque(teleop)
logger.info("Teleop enabled - press any key to start episode")
while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter()
action = teleop.get_action()
robot.send_action(action)
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events["in_reset"] = False
events["start_next_episode"] = False
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
events["resume_policy"] = False
def print_controls(rtc: bool = False):
"""Print control instructions."""
mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")
logger.info(
"%s\n Controls:\n"
" SPACE - Pause policy\n"
" c - Take control\n"
" p - Resume policy after pause/correction\n"
" → - End episode\n"
" ESC - Stop and push to hub",
mode,
)

View File

@@ -35,9 +35,7 @@ def main():
# Fetch the dataset to replay
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)
actions = dataset.select_columns(ACTION)
# Connect to the robot
robot.connect()
@@ -48,7 +46,7 @@ def main():
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
for idx in range(dataset.num_frames):
t0 = time.perf_counter()
# Get recorded action from dataset

View File

@@ -67,9 +67,7 @@ def main():
# Fetch the dataset to replay
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)
actions = dataset.select_columns(ACTION)
# Connect to the robot
robot.connect()
@@ -80,7 +78,7 @@ def main():
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
for idx in range(dataset.num_frames):
t0 = time.perf_counter()
# Get recorded action from dataset

View File

@@ -63,6 +63,31 @@ Usage:
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot-data-collection/folding_final \
--robot.type=bi_openarm_follower \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
--robot.left_arm_config.port=can0 \
--robot.left_arm_config.side=left \
--robot.left_arm_config.can_interface=socketcan \
--robot.left_arm_config.disable_torque_on_disconnect=true \
--robot.left_arm_config.max_relative_target=8.0 \
--robot.right_arm_config.port=can1 \
--robot.right_arm_config.side=right \
--robot.right_arm_config.can_interface=socketcan \
--robot.right_arm_config.disable_torque_on_disconnect=true \
--robot.right_arm_config.max_relative_target=8.0 \
--task="Fold the T-shirt properly" \
--fps=30 \
--duration=2000 \
--interpolation_multiplier=3 \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \
--device=cuda
"""
import logging
@@ -84,24 +109,30 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor import (
NormalizerProcessorStep,
RelativeActionsProcessorStep,
TransitionKey,
create_transition,
)
from lerobot.processor.factory import (
make_default_robot_action_processor,
make_default_robot_observation_processor,
)
from lerobot.processor.relative_action_processor import to_relative_actions
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_so_follower,
koch_follower,
so_follower,
unitree_g1,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
@@ -153,6 +184,7 @@ class RTCDemoConfig(HubMixin):
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
@@ -212,6 +244,35 @@ def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def _reanchor_relative_rtc_prefix(
prev_actions_absolute: Tensor,
current_state: Tensor,
relative_step: RelativeActionsProcessorStep,
normalizer_step: NormalizerProcessorStep | None,
policy_device: torch.device | str,
) -> Tensor:
"""Convert absolute leftovers into model-space for relative-action RTC policies.
When a policy uses relative actions, the RTC prefix (leftover actions from
the previous chunk) is stored in absolute space. Before feeding it back to
the policy we need to re-express it relative to the *current* robot state
and then re-normalize.
"""
state = current_state.detach().cpu()
if state.dim() == 1:
state = state.unsqueeze(0)
action_cpu = prev_actions_absolute.detach().cpu()
mask = relative_step._build_mask(action_cpu.shape[-1])
relative_actions = to_relative_actions(action_cpu, state, mask)
transition = create_transition(action=relative_actions)
if normalizer_step is not None:
transition = normalizer_step(transition)
return transition[TransitionKey.ACTION].to(policy_device)
def get_actions(
policy,
robot: RobotWrapper,
@@ -237,7 +298,15 @@ def get_actions(
fps = cfg.fps
time_per_chunk = 1.0 / fps
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
# Only keep .pos joints + camera streams if the policy was trained on positions,
# not the full pos/vel/torque state the robot exposes.
observation_features_hw = {
key: value
for key, value in robot.observation_features().items()
if key.endswith(".pos") or isinstance(value, tuple)
}
dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
policy_device = policy.config.device
# Load preprocessor and postprocessor from pretrained files
@@ -255,6 +324,25 @@ def get_actions(
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
relative_step = next(
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
None,
)
normalizer_step = next(
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
None,
)
if relative_step is not None:
if relative_step.action_names is None:
cfg_names = getattr(cfg.policy, "action_feature_names", None)
if cfg_names:
relative_step.action_names = list(cfg_names)
else:
relative_step.action_names = [
k for k in robot.robot.action_features if k.endswith(".pos")
]
logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
@@ -297,6 +385,28 @@ def get_actions(
preproceseded_obs = preprocessor(obs_with_policy_features)
# Re-anchor leftover actions for relative-action policies.
# We need the *postprocessed* (absolute) leftover, not the original
# (normalized/relative) one that get_left_over() returns.
if (
prev_actions is not None
and relative_step is not None
and OBS_STATE in obs_with_policy_features
):
with action_queue.lock:
if action_queue.queue is not None:
prev_actions_abs = action_queue.queue[action_queue.last_index :].clone()
else:
prev_actions_abs = None
if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
prev_actions = _reanchor_relative_rtc_prefix(
prev_actions_absolute=prev_actions_abs,
current_state=obs_with_policy_features[OBS_STATE],
relative_step=relative_step,
normalizer_step=normalizer_step,
policy_device=policy_device,
)
# Generate actions WITH RTC
actions = policy.predict_action_chunk(
preproceseded_obs,
@@ -352,21 +462,26 @@ def actor_control(
try:
logger.info("[ACTOR] Starting actor thread")
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]
action_count = 0
action_interval = 1.0 / cfg.fps
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
action_interval = interpolator.get_control_interval(cfg.fps)
while not shutdown_event.is_set():
start_time = time.perf_counter()
# Try to get an action from the queue with timeout
action = action_queue.get()
if interpolator.needs_new_action():
new_action = action_queue.get()
if new_action is not None:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
dt_s = time.perf_counter() - start_time

View File

@@ -68,9 +68,7 @@ def main():
# Fetch the dataset to replay
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)
actions = dataset.select_columns(ACTION)
# Connect to the robot
robot.connect()
@@ -81,7 +79,7 @@ def main():
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
for idx in range(dataset.num_frames):
t0 = time.perf_counter()
# Get recorded action from dataset

View File

@@ -99,7 +99,7 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.9.17"]
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
@@ -145,6 +145,7 @@ wallx = [
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
multi_task_dit = ["lerobot[transformers-dep]"]
groot = [
"lerobot[transformers-dep]",
"lerobot[peft]",

View File

@@ -27,7 +27,8 @@ class DatasetConfig:
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
# Root directory for a concrete local dataset tree (e.g. 'dataset/path'). If None, local datasets are
# looked up under $HF_LEROBOT_HOME/repo_id and Hub downloads use a revision-safe cache under $HF_LEROBOT_HOME/hub.
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
@@ -66,7 +67,8 @@ class EvalConfig:
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
batch_size: int = 50
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: bool = False
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
use_async_envs: bool = True
def __post_init__(self) -> None:
if self.batch_size > self.n_episodes:

View File

@@ -0,0 +1,33 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig
__all__ = [
"EpisodeAwareSampler",
"ImageTransforms",
"ImageTransformsConfig",
"LeRobotDataset",
"LeRobotDatasetMetadata",
"MultiLeRobotDataset",
"StreamingLeRobotDataset",
]

View File

@@ -13,9 +13,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import numpy as np
from lerobot.datasets.io_utils import load_image_as_numpy
from lerobot.utils.constants import ACTION, OBS_STATE
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
@@ -624,3 +629,141 @@ def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
return aggregated_stats
def _get_valid_chunk_starts(episode_indices: np.ndarray, chunk_size: int) -> np.ndarray:
"""Return all start indices where a chunk of ``chunk_size`` stays within one episode."""
total = len(episode_indices)
if total < chunk_size:
return np.array([], dtype=np.int64)
max_start = total - chunk_size
starts = np.arange(max_start + 1)
valid = episode_indices[starts] == episode_indices[starts + chunk_size - 1]
return starts[valid]
def _compute_relative_chunk_batch(
start_indices: np.ndarray,
all_actions: np.ndarray,
all_states: np.ndarray,
chunk_size: int,
relative_mask: np.ndarray,
) -> np.ndarray:
"""Vectorised relative-action computation for a batch of start indices.
Returns an ``(N * chunk_size, action_dim)`` float32 array.
"""
if len(start_indices) == 0:
return np.empty((0, all_actions.shape[1]), dtype=np.float32)
offsets = np.arange(chunk_size)
frame_idx = start_indices[:, None] + offsets[None, :]
chunks = all_actions[frame_idx].copy()
states = all_states[start_indices]
mask_dim = len(relative_mask)
chunks[:, :, :mask_dim] -= states[:, None, :mask_dim] * relative_mask[None, None, :]
return chunks.reshape(-1, all_actions.shape[1])
def compute_relative_action_stats(
hf_dataset,
features: dict,
chunk_size: int,
exclude_joints: list[str] | None = None,
num_workers: int = 0,
) -> dict[str, np.ndarray]:
"""Compute normalization statistics for relative actions over the full dataset.
Iterates *all* valid action chunks (within single episodes), converts them to
relative actions (action current_state), and computes per-dimension
statistics suitable for normalization.
Args:
hf_dataset: The underlying HuggingFace dataset with "action",
"observation.state", and "episode_index" columns.
features: Dataset feature metadata (must contain "action" with "shape"
and optionally "names").
chunk_size: Number of consecutive frames per action chunk.
exclude_joints: Joint names whose dimensions should remain absolute
(not converted to relative actions).
num_workers: Number of parallel threads for computation. Values ≤1
mean single-threaded. Numpy releases the GIL so threads give
real parallelism here.
Returns:
Statistics dict with keys "mean", "std", "min", "max", "q01", …, "q99".
Raises:
ValueError: If the dataset has fewer frames than ``chunk_size``.
RuntimeError: If no valid (single-episode) chunks are found.
"""
from lerobot.processor.relative_action_processor import RelativeActionsProcessorStep
if exclude_joints is None:
exclude_joints = []
action_dim = features[ACTION]["shape"][0]
action_names = features.get(ACTION, {}).get("names")
mask_step = RelativeActionsProcessorStep(
enabled=True,
exclude_joints=exclude_joints,
action_names=action_names,
)
relative_mask = np.array(mask_step._build_mask(action_dim), dtype=np.float32)
logging.info("Loading action/state data for relative action stats...")
all_actions = np.array(hf_dataset[ACTION], dtype=np.float32)
all_states = np.array(hf_dataset[OBS_STATE], dtype=np.float32)
episode_indices = np.array(hf_dataset["episode_index"])
valid_starts = _get_valid_chunk_starts(episode_indices, chunk_size)
if len(valid_starts) == 0:
raise RuntimeError(
f"No valid chunks found (total_frames={len(episode_indices)}, chunk_size={chunk_size})"
)
effective_workers = max(num_workers, 1)
logging.info(
f"Computing relative action stats from {len(valid_starts)} chunks "
f"(chunk_size={chunk_size}, workers={effective_workers})"
)
batch_size = 50_000
batches = [valid_starts[i : i + batch_size] for i in range(0, len(valid_starts), batch_size)]
running_stats = RunningQuantileStats()
if num_workers > 1:
from concurrent.futures import ThreadPoolExecutor, as_completed
with ThreadPoolExecutor(max_workers=num_workers) as pool:
futures = [
pool.submit(
_compute_relative_chunk_batch,
batch,
all_actions,
all_states,
chunk_size,
relative_mask,
)
for batch in batches
]
for future in as_completed(futures):
running_stats.update(future.result())
else:
for batch in batches:
running_stats.update(
_compute_relative_chunk_batch(batch, all_actions, all_states, chunk_size, relative_mask)
)
stats = running_stats.get_statistics()
excluded_dims = int(len(relative_mask) - relative_mask.sum())
total_frames = len(valid_starts) * chunk_size
logging.info(
f"Relative action stats ({len(valid_starts)} chunks, {total_frames} frames): "
f"relative_dims={int(relative_mask.sum())}/{len(relative_mask)} (excluded={excluded_dims}), "
f"mean={np.abs(stats['mean']).mean():.4f}, std={stats['std'].mean():.4f}, "
f"q01={stats['q01'].mean():.4f}, q99={stats['q99'].mean():.4f}"
)
return stats

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from pathlib import Path
import numpy as np
@@ -43,16 +44,24 @@ from lerobot.datasets.utils import (
check_version_compatibility,
flatten_dict,
get_safe_version,
has_legacy_hub_download_metadata,
is_valid_version,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
CODEBASE_VERSION = "v3.0"
class LeRobotDatasetMetadata:
"""Metadata container for a LeRobot dataset.
Manages the ``info.json``, ``stats.json``, ``tasks.parquet``, and
``episodes/`` parquet files that describe a dataset's structure, content,
and statistics.
"""
def __init__(
self,
repo_id: str,
@@ -61,33 +70,57 @@ class LeRobotDatasetMetadata:
force_cache_sync: bool = False,
metadata_buffer_size: int = 10,
):
"""Load or download metadata for an existing LeRobot dataset.
Attempts to load metadata from local disk. If files are missing or
``force_cache_sync`` is ``True``, downloads the ``meta/`` directory from
the Hub.
Args:
repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``).
root: Local directory for the dataset. When provided, Hub downloads
are materialized directly into this directory. When omitted,
existing local datasets are still looked up under
``$HF_LEROBOT_HOME/{repo_id}``, but Hub downloads use a
revision-safe snapshot cache under
``$HF_LEROBOT_HOME/hub``.
revision: Git revision (branch, tag, or commit hash). Defaults to
the current codebase version.
force_cache_sync: If ``True``, re-download metadata from the Hub
even when local files exist.
metadata_buffer_size: Number of episode metadata records to buffer
in memory before flushing to parquet.
"""
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
self.writer = None
self._requested_root = Path(root) if root is not None else None
self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
self._pq_writer = None
self.latest_episode = None
self.metadata_buffer: list[dict] = []
self.metadata_buffer_size = metadata_buffer_size
self._metadata_buffer: list[dict] = []
self._metadata_buffer_size = metadata_buffer_size
self._finalized = False
try:
if force_cache_sync:
if force_cache_sync or (
self._requested_root is None and has_legacy_hub_download_metadata(self.root)
):
raise FileNotFoundError
self.load_metadata()
self._load_metadata()
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
self._pull_from_repo(allow_patterns="meta/")
self._load_metadata()
def _flush_metadata_buffer(self) -> None:
"""Write all buffered episode metadata to parquet file."""
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
if not hasattr(self, "_metadata_buffer") or len(self._metadata_buffer) == 0:
return
combined_dict = {}
for episode_dict in self.metadata_buffer:
for episode_dict in self._metadata_buffer:
for key, value in episode_dict.items():
if key not in combined_dict:
combined_dict[key] = []
@@ -96,40 +129,50 @@ class LeRobotDatasetMetadata:
val = value[0] if isinstance(value, list) else value
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
first_ep = self.metadata_buffer[0]
first_ep = self._metadata_buffer[0]
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
file_idx = first_ep["meta/episodes/file_index"][0]
table = pa.Table.from_pydict(combined_dict)
if not self.writer:
if not self._pq_writer:
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
path.parent.mkdir(parents=True, exist_ok=True)
self.writer = pq.ParquetWriter(
self._pq_writer = pq.ParquetWriter(
path, schema=table.schema, compression="snappy", use_dictionary=True
)
self.writer.write_table(table)
self._pq_writer.write_table(table)
self.latest_episode = self.metadata_buffer[-1]
self.metadata_buffer.clear()
self.latest_episode = self._metadata_buffer[-1]
self._metadata_buffer.clear()
def _close_writer(self) -> None:
"""Close and cleanup the parquet writer if it exists."""
self._flush_metadata_buffer()
writer = getattr(self, "writer", None)
writer = getattr(self, "_pq_writer", None)
if writer is not None:
writer.close()
self.writer = None
self._pq_writer = None
def finalize(self) -> None:
"""Flush metadata buffer and close the parquet writer.
Idempotent — safe to call multiple times.
"""
if getattr(self, "_finalized", False):
return
self._close_writer()
self._finalized = True
def __del__(self):
"""
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
"""
self._close_writer()
"""Safety net: flush and close parquet writer on garbage collection."""
# During interpreter shutdown, referenced objects may already be collected.
with contextlib.suppress(Exception):
self.finalize()
def load_metadata(self):
def _load_metadata(self):
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
@@ -137,22 +180,38 @@ class LeRobotDatasetMetadata:
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
def pull_from_repo(
def _pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
ignore_patterns: list[str] | str | None = None,
) -> None:
if self._requested_root is None:
self.root = Path(
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
cache_dir=HF_LEROBOT_HUB_CACHE,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
)
return
self._requested_root.mkdir(exist_ok=True, parents=True)
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
local_dir=self.root,
local_dir=self._requested_root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
self.root = self._requested_root
@property
def url_root(self) -> str:
"""Hugging Face Hub URL root for this dataset."""
return f"hf://datasets/{self.repo_id}"
@property
@@ -161,6 +220,17 @@ class LeRobotDatasetMetadata:
return packaging.version.parse(self.info["codebase_version"])
def get_data_file_path(self, ep_index: int) -> Path:
"""Return the relative parquet file path for the given episode index.
Args:
ep_index: Zero-based episode index.
Returns:
Path to the parquet file containing this episode's data.
Raises:
IndexError: If ``ep_index`` is out of range.
"""
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
@@ -174,6 +244,19 @@ class LeRobotDatasetMetadata:
return Path(fpath)
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
"""Return the relative video file path for the given episode and video key.
Args:
ep_index: Zero-based episode index.
vid_key: Feature key identifying the video stream
(e.g. ``'observation.images.laptop'``).
Returns:
Path to the video file containing this episode's frames.
Raises:
IndexError: If ``ep_index`` is out of range.
"""
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
@@ -277,6 +360,17 @@ class LeRobotDatasetMetadata:
return None
def save_episode_tasks(self, tasks: list[str]):
"""Register tasks for the current episode and persist to disk.
New tasks that do not already exist in the dataset are assigned
sequential task indices and appended to the tasks parquet file.
Args:
tasks: List of unique task descriptions in natural language.
Raises:
ValueError: If ``tasks`` contains duplicates.
"""
if len(set(tasks)) != len(tasks):
raise ValueError(f"Tasks are not unique: {tasks}")
@@ -336,8 +430,8 @@ class LeRobotDatasetMetadata:
latest_path = (
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
if self.writer is None
else self.writer.where
if self._pq_writer is None
else self._pq_writer.where
)
if Path(latest_path).exists():
@@ -359,10 +453,10 @@ class LeRobotDatasetMetadata:
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
# Add to buffer
self.metadata_buffer.append(episode_dict)
self._metadata_buffer.append(episode_dict)
self.latest_episode = episode_dict
if len(self.metadata_buffer) >= self.metadata_buffer_size:
if len(self._metadata_buffer) >= self._metadata_buffer_size:
self._flush_metadata_buffer()
def save_episode(
@@ -373,6 +467,20 @@ class LeRobotDatasetMetadata:
episode_stats: dict[str, dict],
episode_metadata: dict,
) -> None:
"""Persist episode metadata, update dataset info, and aggregate stats.
Writes the episode's metadata to the buffered parquet writer, increments
the total episode/frame counters in ``info.json``, and merges the
episode's statistics into the running dataset statistics.
Args:
episode_index: Zero-based index of the episode being saved.
episode_length: Number of frames in this episode.
episode_tasks: List of task descriptions for this episode.
episode_stats: Per-feature statistics for this episode.
episode_metadata: Additional metadata (chunk/file indices, frame
ranges, video timestamps, etc.).
"""
episode_dict = {
"episode_index": episode_index,
"tasks": episode_tasks,
@@ -479,10 +587,36 @@ class LeRobotDatasetMetadata:
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
"""Create metadata for a new LeRobot dataset from scratch.
Initializes the ``info.json`` file on disk with the provided feature
schema and dataset settings. No episode data is written yet.
Args:
repo_id: Repository identifier (e.g. ``'user/my_dataset'``).
fps: Frames per second used during data collection.
features: Feature specification dict mapping feature names to their
type/shape metadata.
robot_type: Optional robot type string stored in metadata.
root: Local directory for the dataset. Defaults to
``$HF_LEROBOT_HOME/{repo_id}``. Must not already exist.
use_videos: If ``True``, visual modalities are encoded as MP4 videos.
metadata_buffer_size: Number of episode metadata records to buffer
before flushing to parquet.
chunks_size: Max number of files per chunk directory. ``None`` uses
the default.
data_files_size_in_mb: Max parquet file size in MB. ``None`` uses the
default.
video_files_size_in_mb: Max video file size in MB. ``None`` uses the
default.
Returns:
A new :class:`LeRobotDatasetMetadata` instance.
"""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj._requested_root = Path(root) if root is not None else None
obj.root = obj._requested_root if obj._requested_root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
@@ -510,8 +644,9 @@ class LeRobotDatasetMetadata:
)
write_json(obj.info, obj.root / INFO_PATH)
obj.revision = None
obj.writer = None
obj._pq_writer = None
obj.latest_episode = None
obj.metadata_buffer = []
obj.metadata_buffer_size = metadata_buffer_size
obj._metadata_buffer = []
obj._metadata_buffer_size = metadata_buffer_size
obj._finalized = False
return obj

View File

@@ -0,0 +1,288 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Private reader component for LeRobotDataset. Handles random-access reading (HF dataset, delta indices, video decoding)."""
from collections.abc import Callable
from pathlib import Path
import datasets
import torch
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import (
check_delta_timestamps,
get_delta_indices,
get_hf_features_from_features,
)
from lerobot.datasets.io_utils import (
hf_transform_to_torch,
load_nested_dataset,
)
from lerobot.datasets.video_utils import decode_video_frames
class DatasetReader:
"""Encapsulates read-side state and methods for LeRobotDataset.
Owns: hf_dataset, _absolute_to_relative_idx, delta_indices.
"""
def __init__(
self,
meta: LeRobotDatasetMetadata,
root: Path,
episodes: list[int] | None,
tolerance_s: float,
video_backend: str,
delta_timestamps: dict[str, list[float]] | None,
image_transforms: Callable | None,
):
"""Initialize the reader with metadata, filtering, and transform config.
The HF dataset is not loaded here — call :meth:`try_load` or
:meth:`load_and_activate` afterward.
Args:
meta: Dataset metadata instance.
root: Local dataset root directory.
episodes: Optional list of episode indices to select. ``None``
means all episodes.
tolerance_s: Timestamp synchronization tolerance in seconds.
video_backend: Video decoding backend identifier.
delta_timestamps: Optional dict mapping feature keys to lists of
relative timestamp offsets for temporal context windows.
image_transforms: Optional torchvision v2 transform applied to
visual features.
"""
self._meta = meta
self.root = root
self.episodes = episodes
self._tolerance_s = tolerance_s
self._video_backend = video_backend
self._image_transforms = image_transforms
self.hf_dataset: datasets.Dataset | None = None
self._absolute_to_relative_idx: dict[int, int] | None = None
# Setup delta_indices (doesn't depend on hf_dataset)
self.delta_indices = None
if delta_timestamps is not None:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient."""
try:
self.hf_dataset = self._load_hf_dataset()
except (FileNotFoundError, NotADirectoryError):
self.hf_dataset = None
return False
if not self._check_cached_episodes_sufficient():
self.hf_dataset = None
return False
self._build_index_mapping()
return True
def load_and_activate(self) -> None:
"""Load HF dataset from disk and build index mapping. Call after data is on disk."""
self.hf_dataset = self._load_hf_dataset()
self._build_index_mapping()
def _build_index_mapping(self) -> None:
"""Build absolute-to-relative index mapping from loaded hf_dataset."""
self._absolute_to_relative_idx = None
if self.episodes is not None and self.hf_dataset is not None:
self._absolute_to_relative_idx = {
abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx
for rel_idx, abs_idx in enumerate(self.hf_dataset["index"])
}
@property
def num_frames(self) -> int:
"""Number of frames in selected episodes."""
if self.episodes is not None and self.hf_dataset is not None:
return len(self.hf_dataset)
return self._meta.total_frames
@property
def num_episodes(self) -> int:
"""Number of episodes selected."""
return len(self.episodes) if self.episodes is not None else self._meta.total_episodes
def _load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self._meta.features)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def _check_cached_episodes_sufficient(self) -> bool:
"""Check if the cached dataset contains all requested episodes and their video files."""
if self.hf_dataset is None or len(self.hf_dataset) == 0:
return False
available_episodes = {
ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
for ep_idx in self.hf_dataset.unique("episode_index")
}
if self.episodes is None:
requested_episodes = set(range(self._meta.total_episodes))
else:
requested_episodes = set(self.episodes)
if not requested_episodes.issubset(available_episodes):
return False
if len(self._meta.video_keys) > 0:
for ep_idx in requested_episodes:
for vid_key in self._meta.video_keys:
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
if not video_path.exists():
return False
return True
def get_episodes_file_paths(self) -> list[Path]:
"""Return deduplicated file paths (data + video) for selected episodes.
Used to build the ``allow_patterns`` list for ``snapshot_download``.
"""
episodes = self.episodes if self.episodes is not None else list(range(self._meta.total_episodes))
fpaths = [str(self._meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
if len(self._meta.video_keys) > 0:
video_files = [
str(self._meta.get_video_file_path(ep_idx, vid_key))
for vid_key in self._meta.video_keys
for ep_idx in episodes
]
fpaths += video_files
# episodes are stored in the same files, so we return unique paths only
fpaths = list(set(fpaths))
return fpaths
def _get_query_indices(
self, abs_idx: int, ep_idx: int
) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
"""Compute query indices for delta timestamps."""
ep = self._meta.episodes[ep_idx]
ep_start = ep["dataset_from_index"]
ep_end = ep["dataset_to_index"]
query_indices = {
key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
padding = {
f"{key}_is_pad": torch.BoolTensor(
[(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx]
)
for key, delta_idx in self.delta_indices.items()
}
return query_indices, padding
def _get_query_timestamps(
self,
current_ts: float,
query_indices: dict[str, list[int]] | None = None,
) -> dict[str, list[float]]:
query_timestamps = {}
for key in self._meta.video_keys:
if query_indices is not None and key in query_indices:
if self._absolute_to_relative_idx is not None:
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
timestamps = self.hf_dataset[relative_indices]["timestamp"]
else:
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
else:
query_timestamps[key] = [current_ts]
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
"""Query dataset for indices across keys, skipping video keys."""
result: dict = {}
for key, q_idx in query_indices.items():
if key in self._meta.video_keys:
continue
relative_indices = (
q_idx
if self._absolute_to_relative_idx is None
else [self._absolute_to_relative_idx[idx] for idx in q_idx]
)
try:
result[key] = torch.stack(self.hf_dataset[key][relative_indices])
except (KeyError, TypeError, IndexError):
result[key] = torch.stack(self.hf_dataset[relative_indices][key])
return result
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
Segmentation Fault.
"""
ep = self._meta.episodes[ep_idx]
item = {}
for vid_key, query_ts in query_timestamps.items():
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend)
item[vid_key] = frames.squeeze(0)
return item
def get_item(self, idx) -> dict:
"""Core __getitem__ logic. Assumes hf_dataset is loaded.
``idx`` is a *relative* index into the (possibly episode-filtered)
HF dataset, **not** the absolute frame index stored in the ``index``
column. The absolute index is retrieved from the row itself.
"""
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item()
abs_idx = item["index"].item()
query_indices = None
if self.delta_indices is not None:
query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
for key, val in query_result.items():
item[key] = val
if len(self._meta.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
if self._image_transforms is not None:
image_keys = self._meta.camera_keys
for cam in image_keys:
item[cam] = self._image_transforms(item[cam])
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self._meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
return item

View File

@@ -37,7 +37,11 @@ import torch
from tqdm import tqdm
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.compute_stats import (
aggregate_stats,
compute_episode_stats,
compute_relative_action_stats,
)
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.io_utils import (
get_parquet_file_size_in_mb,
@@ -56,7 +60,7 @@ from lerobot.datasets.utils import (
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
@@ -891,7 +895,7 @@ def _copy_and_reindex_episodes_metadata(
total_frames += src_episode["length"]
dst_meta._close_writer()
dst_meta.finalize()
dst_meta.info.update(
{
@@ -1533,6 +1537,114 @@ def modify_tasks(
return dataset
def recompute_stats(
dataset: LeRobotDataset,
skip_image_video: bool = True,
relative_action: bool = False,
relative_exclude_joints: list[str] | None = None,
chunk_size: int = 50,
num_workers: int = 0,
) -> LeRobotDataset:
"""Recompute stats.json from scratch by iterating all episodes.
Args:
dataset: The LeRobotDataset to recompute stats for.
skip_image_video: If True (default), only recompute stats for numeric features
(action, state, etc.) and keep existing image/video stats unchanged.
relative_action: If True, compute action stats in relative space by
iterating all valid action chunks and subtracting the current state.
This matches the normalization distribution the model sees during
training with ``use_relative_actions=True``.
relative_exclude_joints: Joint names to exclude from relative conversion when
relative_action=True. These dims keep absolute stats.
chunk_size: Action chunk size used for relative stats computation. Should match
``policy.chunk_size``. Only used when ``relative_action=True``.
num_workers: Number of parallel threads for relative action stats computation.
Values ≤1 mean single-threaded. Only used when ``relative_action=True``.
Returns:
The same dataset with updated stats.
"""
features = dataset.meta.features
meta_keys = {"index", "episode_index", "task_index", "frame_index", "timestamp"}
numeric_features = {
k: v
for k, v in features.items()
if v["dtype"] not in ["image", "video", "string"] and k not in meta_keys
}
if skip_image_video:
features_to_compute = numeric_features
else:
features_to_compute = {
k: v for k, v in features.items() if v["dtype"] != "string" and k not in meta_keys
}
# When relative_action is enabled, compute action stats via chunk-based sampling
# (matching what the model sees during training) and skip action in the
# per-episode pass below.
relative_action_stats = None
if relative_action and ACTION in features and OBS_STATE in features:
if relative_exclude_joints is None:
relative_exclude_joints = ["gripper"]
relative_action_stats = compute_relative_action_stats(
hf_dataset=dataset.hf_dataset,
features=features,
chunk_size=chunk_size,
exclude_joints=relative_exclude_joints,
num_workers=num_workers,
)
features_to_compute.pop(ACTION, None)
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
data_dir = dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
all_episode_stats = []
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
df = pd.read_parquet(parquet_path)
for ep_idx in sorted(df["episode_index"].unique()):
ep_df = df[df["episode_index"] == ep_idx]
episode_data = {}
for key in numeric_keys:
if key in ep_df.columns:
values = ep_df[key].values
if hasattr(values[0], "__len__"):
episode_data[key] = np.stack(values)
else:
episode_data[key] = np.array(values)
ep_stats = compute_episode_stats(episode_data, features_to_compute)
all_episode_stats.append(ep_stats)
if features_to_compute and not all_episode_stats:
logging.warning("No episode stats computed")
return dataset
new_stats = aggregate_stats(all_episode_stats) if all_episode_stats else {}
if relative_action_stats is not None:
new_stats[ACTION] = relative_action_stats
# Merge: keep existing stats for features we didn't recompute
if dataset.meta.stats:
for key, value in dataset.meta.stats.items():
if key not in new_stats:
new_stats[key] = value
write_stats(new_stats, dataset.root)
dataset.meta.stats = new_stats
logging.info("Stats recomputed successfully")
return dataset
def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path | None = None,

View File

@@ -0,0 +1,634 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Private writer component for LeRobotDataset. Handles sequential recording (episode buffer, ParquetWriter, image writer, video encoding)."""
from __future__ import annotations
import concurrent.futures
import contextlib
import logging
import shutil
import tempfile
from pathlib import Path
import datasets
import numpy as np
import pandas as pd
import PIL.Image
import pyarrow.parquet as pq
import torch
from lerobot.datasets.compute_stats import compute_episode_stats
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import (
get_hf_features_from_features,
validate_episode_buffer,
validate_frame,
)
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.datasets.io_utils import (
embed_images,
get_file_size_in_mb,
load_episodes,
write_info,
)
from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_IMAGE_PATH,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import (
StreamingVideoEncoder,
concatenate_video_files,
encode_video_frames,
get_video_duration_in_s,
)
logger = logging.getLogger(__name__)
def _encode_video_worker(
video_key: str,
episode_index: int,
root: Path,
fps: int,
vcodec: str = "libsvtav1",
encoder_threads: int | None = None,
) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent
encode_video_frames(
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
)
shutil.rmtree(img_dir)
return temp_path
class DatasetWriter:
"""Encapsulates write-side state and methods for LeRobotDataset.
Owns: episode_buffer, image_writer, _pq_writer (ParquetWriter), _latest_episode,
_current_file_start_frame, _streaming_encoder, _episodes_since_last_encoding, _recorded_frames.
"""
def __init__(
self,
meta: LeRobotDatasetMetadata,
root: Path,
vcodec: str,
encoder_threads: int | None,
batch_encoding_size: int,
streaming_encoder: StreamingVideoEncoder | None = None,
initial_frames: int = 0,
):
"""Initialize the writer with metadata, codec, and encoding config.
Args:
meta: Dataset metadata instance (used for feature schema, chunk
settings, and episode persistence).
root: Local dataset root directory.
vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``).
encoder_threads: Threads per encoder instance. ``None`` for auto.
batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos.
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
for real-time encoding. ``None`` disables streaming mode.
initial_frames: Starting frame count (non-zero when resuming).
"""
self._meta = meta
self._root = root
self._vcodec = vcodec
self._encoder_threads = encoder_threads
self._batch_encoding_size = batch_encoding_size
self._streaming_encoder = streaming_encoder
# Writer state
self.image_writer: AsyncImageWriter | None = None
self.episode_buffer: dict = self._create_episode_buffer()
self._pq_writer: pq.ParquetWriter | None = None
self._latest_episode: dict | None = None
self._current_file_start_frame: int | None = None
self._episodes_since_last_encoding: int = 0
self._recorded_frames: int = initial_frames
self._finalized = False
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self._meta.total_episodes if episode_index is None else episode_index
ep_buffer = {}
ep_buffer["size"] = 0
ep_buffer["task"] = []
for key in self._meta.features:
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
return ep_buffer
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return self._root / fpath
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
def _save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
) -> None:
if self.image_writer is None:
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
write_image(image, fpath, compress_level=compress_level)
else:
self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level)
def add_frame(self, frame: dict) -> None:
"""
Add a single frame to the current episode buffer.
Apart from images written to a temporary directory, nothing is written to disk
until ``save_episode()`` is called.
The caller must provide all user-defined features plus ``"task"``, and must
not provide ``"timestamp"`` or ``"frame_index"``; those are computed
automatically.
"""
# Convert torch to numpy if needed
for name in frame:
if isinstance(frame[name], torch.Tensor):
frame[name] = frame[name].numpy()
validate_frame(frame, self._meta.features)
if self.episode_buffer is None:
self.episode_buffer = self._create_episode_buffer()
# Automatically add frame_index and timestamp to episode buffer
frame_index = self.episode_buffer["size"]
timestamp = frame_index / self._meta.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
self.episode_buffer["task"].append(frame.pop("task"))
# Start streaming encoder on first frame of episode
if frame_index == 0 and self._streaming_encoder is not None:
self._streaming_encoder.start_episode(
video_keys=list(self._meta.video_keys),
temp_dir=self._root,
)
# Add frame features to episode_buffer
for key in frame:
if key not in self._meta.features:
raise ValueError(
f"An element of the frame is not in the features. '{key}' not in '{self._meta.features.keys()}'."
)
if self._meta.features[key]["dtype"] == "video" and self._streaming_encoder is not None:
self._streaming_encoder.feed_frame(key, frame[key])
self.episode_buffer[key].append(None)
elif self._meta.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
)
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
compress_level = 1 if self._meta.features[key]["dtype"] == "video" else 6
self._save_image(frame[key], img_path, compress_level)
self.episode_buffer[key].append(str(img_path))
else:
self.episode_buffer[key].append(frame[key])
self.episode_buffer["size"] += 1
def save_episode(
self,
episode_data: dict | None = None,
parallel_encoding: bool = True,
) -> None:
"""Save the current episode in self.episode_buffer to disk."""
episode_buffer = episode_data if episode_data is not None else self.episode_buffer
validate_episode_buffer(episode_buffer, self._meta.total_episodes, self._meta.features)
# size and task are special cases that won't be added to hf_dataset
episode_length = episode_buffer.pop("size")
tasks = episode_buffer.pop("task")
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
episode_buffer["index"] = np.arange(self._meta.total_frames, self._meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
# Update tasks and task indices with new tasks if any
self._meta.save_episode_tasks(episode_tasks)
# Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array([self._meta.get_task_index(task) for task in tasks])
for key, ft in self._meta.features.items():
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
episode_buffer[key] = np.stack(episode_buffer[key])
# Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer()
has_video_keys = len(self._meta.video_keys) > 0
use_streaming = self._streaming_encoder is not None and has_video_keys
use_batched_encoding = self._batch_encoding_size > 1
if use_streaming:
non_video_buffer = {
k: v
for k, v in episode_buffer.items()
if self._meta.features.get(k, {}).get("dtype") not in ("video",)
}
non_video_features = {k: v for k, v in self._meta.features.items() if v["dtype"] != "video"}
ep_stats = compute_episode_stats(non_video_buffer, non_video_features)
else:
ep_stats = compute_episode_stats(episode_buffer, self._meta.features)
ep_metadata = self._save_episode_data(episode_buffer)
if use_streaming:
streaming_results = self._streaming_encoder.finish_episode()
for video_key in self._meta.video_keys:
temp_path, video_stats = streaming_results[video_key]
if video_stats is not None:
ep_stats[video_key] = {
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
for k, v in video_stats.items()
}
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
elif has_video_keys and not use_batched_encoding:
num_cameras = len(self._meta.video_keys)
if parallel_encoding and num_cameras > 1:
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
future_to_key = {
executor.submit(
_encode_video_worker,
video_key,
episode_index,
self._root,
self._meta.fps,
self._vcodec,
self._encoder_threads,
): video_key
for video_key in self._meta.video_keys
}
results = {}
for future in concurrent.futures.as_completed(future_to_key):
video_key = future_to_key[future]
try:
temp_path = future.result()
results[video_key] = temp_path
except Exception as exc:
logger.error(f"Video encoding failed for {video_key}: {exc}")
raise exc
for video_key in self._meta.video_keys:
temp_path = results[video_key]
ep_metadata.update(
self._save_episode_video(video_key, episode_index, temp_path=temp_path)
)
else:
for video_key in self._meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index))
# `meta.save_episode` need to be executed after encoding the videos
self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
if has_video_keys and use_batched_encoding:
self._episodes_since_last_encoding += 1
if self._episodes_since_last_encoding == self._batch_encoding_size:
start_ep = self._meta.total_episodes - self._batch_encoding_size
end_ep = self._meta.total_episodes
self._batch_save_episode_video(start_ep, end_ep)
self._episodes_since_last_encoding = 0
if episode_data is None:
self.clear_episode_buffer(delete_images=len(self._meta.image_keys) > 0)
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
"""Batch save videos for multiple episodes."""
if end_episode is None:
end_episode = self._meta.total_episodes
logger.info(
f"Batch encoding {self._batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}"
)
chunk_idx = self._meta.episodes[start_episode]["data/chunk_index"]
file_idx = self._meta.episodes[start_episode]["data/file_index"]
episode_df_path = self._root / DEFAULT_EPISODES_PATH.format(
chunk_index=chunk_idx, file_index=file_idx
)
episode_df = pd.read_parquet(episode_df_path)
for ep_idx in range(start_episode, end_episode):
logger.info(f"Encoding videos for episode {ep_idx}")
if (
self._meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
or self._meta.episodes[ep_idx]["data/file_index"] != file_idx
):
episode_df.to_parquet(episode_df_path)
self._meta.episodes = load_episodes(self._root)
chunk_idx = self._meta.episodes[ep_idx]["data/chunk_index"]
file_idx = self._meta.episodes[ep_idx]["data/file_index"]
episode_df_path = self._root / DEFAULT_EPISODES_PATH.format(
chunk_index=chunk_idx, file_index=file_idx
)
episode_df = pd.read_parquet(episode_df_path)
video_ep_metadata = {}
for video_key in self._meta.video_keys:
video_ep_metadata.update(self._save_episode_video(video_key, ep_idx))
video_ep_metadata.pop("episode_index")
video_ep_df = pd.DataFrame(video_ep_metadata, index=[ep_idx]).convert_dtypes(
dtype_backend="pyarrow"
)
episode_df = episode_df.combine_first(video_ep_df)
episode_df.to_parquet(episode_df_path)
self._meta.episodes = load_episodes(self._root)
def _save_episode_data(self, episode_buffer: dict) -> dict:
"""Save episode data to a parquet file."""
# Use metadata features as the authoritative schema
hf_features = get_hf_features_from_features(self._meta.features)
ep_dict = {key: episode_buffer[key] for key in hf_features}
ep_dataset = datasets.Dataset.from_dict(ep_dict, features=hf_features, split="train")
ep_dataset = embed_images(ep_dataset)
ep_num_frames = len(ep_dataset)
if self._latest_episode is None:
chunk_idx, file_idx = 0, 0
global_frame_index = 0
self._current_file_start_frame = 0
if self._meta.episodes is not None and len(self._meta.episodes) > 0:
latest_ep = self._meta.episodes[-1]
global_frame_index = latest_ep["dataset_to_index"]
chunk_idx = latest_ep["data/chunk_index"]
file_idx = latest_ep["data/file_index"]
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size)
self._current_file_start_frame = global_frame_index
else:
latest_ep = self._latest_episode
chunk_idx = latest_ep["data/chunk_index"]
file_idx = latest_ep["data/file_index"]
global_frame_index = latest_ep["index"][-1] + 1
latest_path = self._root / self._meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
latest_size_in_mb = get_file_size_in_mb(latest_path)
frames_in_current_file = global_frame_index - self._current_file_start_frame
av_size_per_frame = (
latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
)
if latest_size_in_mb + av_size_per_frame * ep_num_frames >= self._meta.data_files_size_in_mb:
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size)
self.close_writer()
self._current_file_start_frame = global_frame_index
ep_dict["data/chunk_index"] = chunk_idx
ep_dict["data/file_index"] = file_idx
path = self._root / self._meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
table = ep_dataset.with_format("arrow")[:]
if not self._pq_writer:
self._pq_writer = pq.ParquetWriter(
path, schema=table.schema, compression="snappy", use_dictionary=True
)
self._pq_writer.write_table(table)
metadata = {
"data/chunk_index": chunk_idx,
"data/file_index": file_idx,
"dataset_from_index": global_frame_index,
"dataset_to_index": global_frame_index + ep_num_frames,
}
self._latest_episode = {**ep_dict, **metadata}
self._recorded_frames += ep_num_frames
return metadata
def _save_episode_video(
self,
video_key: str,
episode_index: int,
temp_path: Path | None = None,
) -> dict:
if temp_path is None:
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
else:
ep_path = temp_path
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
if (
episode_index == 0
or self._meta.latest_episode is None
or f"videos/{video_key}/chunk_index" not in self._meta.latest_episode
):
chunk_idx, file_idx = 0, 0
if self._meta.episodes is not None and len(self._meta.episodes) > 0:
old_chunk_idx = self._meta.episodes[-1][f"videos/{video_key}/chunk_index"]
old_file_idx = self._meta.episodes[-1][f"videos/{video_key}/file_index"]
chunk_idx, file_idx = update_chunk_file_indices(
old_chunk_idx, old_file_idx, self._meta.chunks_size
)
latest_duration_in_s = 0.0
new_path = self._root / self._meta.video_path.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
new_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(ep_path), str(new_path))
else:
latest_ep = self._meta.latest_episode
chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0]
file_idx = latest_ep[f"videos/{video_key}/file_index"][0]
latest_path = self._root / self._meta.video_path.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
latest_size_in_mb = get_file_size_in_mb(latest_path)
latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0]
if latest_size_in_mb + ep_size_in_mb >= self._meta.video_files_size_in_mb:
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size)
new_path = self._root / self._meta.video_path.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
new_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(ep_path), str(new_path))
latest_duration_in_s = 0.0
else:
concatenate_video_files(
[latest_path, ep_path],
latest_path,
)
# Remove temporary directory
shutil.rmtree(str(ep_path.parent))
# Update video info (only needed when first episode is encoded)
if episode_index == 0:
self._meta.update_video_info(video_key)
write_info(self._meta.info, self._meta.root)
metadata = {
"episode_index": episode_index,
f"videos/{video_key}/chunk_index": chunk_idx,
f"videos/{video_key}/file_index": file_idx,
f"videos/{video_key}/from_timestamp": latest_duration_in_s,
f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
}
return metadata
def clear_episode_buffer(self, delete_images: bool = True) -> None:
"""Discard the current episode buffer and optionally delete temp images.
Args:
delete_images: If ``True``, remove temporary image directories
written for the current episode.
"""
# Cancel streaming encoder if active
if self._streaming_encoder is not None:
self._streaming_encoder.cancel_episode()
if delete_images:
if self.image_writer is not None:
self._wait_image_writer()
episode_index = self.episode_buffer["episode_index"]
# episode_index is `int` when freshly created, but becomes `np.ndarray` after
# save_episode() mutates the buffer. Handle both types here.
if isinstance(episode_index, np.ndarray):
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
for cam_key in self._meta.image_keys:
img_dir = self._get_image_file_dir(episode_index, cam_key)
if img_dir.is_dir():
shutil.rmtree(img_dir)
self.episode_buffer = self._create_episode_buffer()
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
"""Start an :class:`AsyncImageWriter` for background image persistence.
Args:
num_processes: Number of subprocesses. ``0`` means threads only.
num_threads: Number of threads per process.
"""
if isinstance(self.image_writer, AsyncImageWriter):
logger.warning(
"You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
)
self.image_writer = AsyncImageWriter(
num_processes=num_processes,
num_threads=num_threads,
)
def stop_image_writer(self) -> None:
"""Stop the image writer (needed before pickling the dataset for DataLoader)."""
if self.image_writer is not None:
self.image_writer.stop()
self.image_writer = None
def _wait_image_writer(self) -> None:
"""Wait for asynchronous image writer to finish."""
if self.image_writer is not None:
self.image_writer.wait_until_done()
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
return _encode_video_worker(
video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads
)
def close_writer(self) -> None:
"""Close and cleanup the parquet writer if it exists."""
if self._pq_writer is not None:
self._pq_writer.close()
self._pq_writer = None
def flush_pending_videos(self) -> None:
"""Flush any pending video encoding (streaming or batch).
For streaming encoding: closes the encoder.
For batch encoding: encodes any remaining episodes that haven't been batch-encoded yet.
"""
if self._streaming_encoder is not None:
self._streaming_encoder.close()
elif self._episodes_since_last_encoding > 0:
start_ep = self._meta.total_episodes - self._episodes_since_last_encoding
end_ep = self._meta.total_episodes
logger.info(
f"Encoding remaining {self._episodes_since_last_encoding} episodes, "
f"from episode {start_ep} to {end_ep - 1}"
)
self._batch_save_episode_video(start_ep, end_ep)
def cancel_pending_videos(self) -> None:
"""Cancel any in-progress streaming encoding without flushing."""
if self._streaming_encoder is not None:
self._streaming_encoder.cancel_episode()
def cleanup_interrupted_episode(self, episode_index: int) -> None:
"""Remove temporary image directories for an interrupted episode."""
for key in self._meta.video_keys:
img_dir = self._get_image_file_path(
episode_index=episode_index, image_key=key, frame_index=0
).parent
if img_dir.exists():
logger.debug(
f"Cleaning up interrupted episode images for episode {episode_index}, camera {key}"
)
shutil.rmtree(img_dir)
def finalize(self) -> None:
"""Flush all pending work and release all resources.
Idempotent — safe to call multiple times.
"""
if getattr(self, "_finalized", False):
return
# 1. Wait for async image writes to complete, then stop
if self.image_writer is not None:
self.image_writer.wait_until_done()
self.image_writer.stop()
self.image_writer = None
# 2. Flush pending video encoding (streaming or batch)
self.flush_pending_videos()
# 3. Close own parquet writer
self.close_writer()
# 4. Finalize metadata (idempotent)
self._meta.finalize()
self._finalized = True
def __del__(self):
"""Safety net: release resources on garbage collection."""
# During interpreter shutdown, referenced objects may already be collected.
with contextlib.suppress(Exception):
self.finalize()

View File

@@ -365,6 +365,10 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
def validate_frame(frame: dict, features: dict) -> None:
# DEFAULT_FEATURES (timestamp, frame_index, episode_index, index, task_index) are
# auto-populated by the recording pipeline (add_frame / save_episode) and must not
# be supplied by the caller. Excluding them here means any frame dict that contains
# these keys will be rejected as extra features.
expected_features = set(features) - set(DEFAULT_FEATURES)
actual_features = set(frame)

View File

@@ -32,10 +32,10 @@ def safe_stop_image_writer(func):
return func(*args, **kwargs)
except Exception as e:
dataset = kwargs.get("dataset")
image_writer = getattr(dataset, "image_writer", None) if dataset else None
if image_writer is not None:
writer = getattr(dataset, "writer", None) if dataset else None
if writer is not None and writer.image_writer is not None:
logger.warning("Waiting for image writer to terminate...")
image_writer.stop()
writer.image_writer.stop()
raise e
return wrapper

File diff suppressed because it is too large Load Diff

View File

@@ -22,6 +22,7 @@ import torch
import torch.utils
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import VideoFrame
from lerobot.utils.constants import HF_LEROBOT_HOME
@@ -125,7 +126,13 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
features.update(
{
k: v
for k, v in get_hf_features_from_features(dataset.features).items()
if k not in self.disabled_features
}
)
return features
@property

View File

@@ -255,7 +255,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
Args:
repo_id (str): This is the repo id that will be used to fetch the dataset.
root (Path | None, optional): Local directory to use for downloading/writing files.
root (Path | None, optional): Local directory to use for local datasets. When omitted, Hub
metadata is resolved through a revision-safe snapshot cache under
``$HF_LEROBOT_HOME/hub``.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list.
image_transforms (Callable | None, optional): Transform to apply to image data.
@@ -271,7 +273,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
"""
super().__init__()
self.repo_id = repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self._requested_root = Path(root) if root else None
self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
self.streaming_from_local = root is not None
self.image_transforms = image_transforms
@@ -288,12 +291,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
self.root.mkdir(exist_ok=True, parents=True)
if self._requested_root is not None:
self.root.mkdir(exist_ok=True, parents=True)
# Load metadata
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync
)
self.root = self.meta.root
self.revision = self.meta.revision
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)

View File

@@ -18,6 +18,7 @@ import importlib.resources
import json
import logging
from collections.abc import Iterator
from pathlib import Path
from typing import Any
import datasets
@@ -101,6 +102,18 @@ DEFAULT_FEATURES = {
}
def has_legacy_hub_download_metadata(root: Path) -> bool:
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
``snapshot_download(local_dir=...)`` stores lightweight metadata under
``<local_dir>/.cache/huggingface/download/``. The presence of this
directory is a reliable indicator that the dataset was downloaded with
the old non-revision-safe ``local_dir`` mode and should be re-fetched
through the snapshot cache instead.
"""
return (root / ".cache" / "huggingface" / "download").exists()
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
if file_idx == chunks_size - 1:
file_idx = 0

View File

@@ -741,6 +741,7 @@ class StreamingVideoEncoder:
self._video_paths: dict[str, Path] = {}
self._dropped_frames: dict[str, int] = {}
self._episode_active = False
self._closed = False
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
"""Start encoder threads for a new episode.
@@ -895,8 +896,11 @@ class StreamingVideoEncoder:
def close(self) -> None:
"""Close the encoder, canceling any in-progress episode."""
if self._closed:
return
if self._episode_active:
self.cancel_episode()
self._closed = True
def _cleanup(self) -> None:
"""Clean up queues and thread tracking dicts."""
@@ -1063,43 +1067,19 @@ class VideoEncodingManager:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
streaming_encoder = getattr(self.dataset, "_streaming_encoder", None)
writer = self.dataset.writer
if writer is not None:
if exc_type is not None and writer._streaming_encoder is not None:
writer.cancel_pending_videos()
if streaming_encoder is not None:
# Handle streaming encoder cleanup
if exc_type is not None:
streaming_encoder.cancel_episode()
streaming_encoder.close()
elif self.dataset.episodes_since_last_encoding > 0:
# Handle any remaining episodes that haven't been batch encoded
if exc_type is not None:
logger.info("Exception occurred. Encoding remaining episodes before exit...")
else:
logger.info("Recording stopped. Encoding remaining episodes...")
# finalize() handles flush_pending_videos + parquet + metadata
self.dataset.finalize()
start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
end_ep = self.dataset.num_episodes
logger.info(
f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
f"from episode {start_ep} to {end_ep - 1}"
)
self.dataset._batch_save_episode_video(start_ep, end_ep)
# Finalize the dataset to properly close all writers
self.dataset.finalize()
# Clean up episode images if recording was interrupted (only for non-streaming mode)
if exc_type is not None and streaming_encoder is None:
interrupted_episode_index = self.dataset.num_episodes
for key in self.dataset.meta.video_keys:
img_dir = self.dataset._get_image_file_path(
episode_index=interrupted_episode_index, image_key=key, frame_index=0
).parent
if img_dir.exists():
logger.debug(
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
)
shutil.rmtree(img_dir)
# Clean up episode images if recording was interrupted (only for non-streaming mode)
if exc_type is not None and writer._streaming_encoder is None:
writer.cleanup_interrupted_episode(self.dataset.num_episodes)
else:
self.dataset.finalize()
# Clean up any remaining images directory if it's empty
img_dir = self.dataset.root / "images"

View File

@@ -12,11 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import importlib
from dataclasses import dataclass, field, fields
from typing import Any
import draccus
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.robots import RobotConfig
@@ -67,6 +72,45 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
def gym_kwargs(self) -> dict:
raise NotImplementedError()
def create_envs(
self,
n_envs: int,
use_async_envs: bool = True,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Create {suite: {task_id: VectorEnv}}.
Default: single-task env via gym.make(). Multi-task benchmarks override.
AsyncVectorEnv is the default for n_envs > 1; auto-downgraded to Sync for n_envs=1.
"""
env_cls = gym.vector.AsyncVectorEnv if (use_async_envs and n_envs > 1) else gym.vector.SyncVectorEnv
if self.gym_id not in gym_registry:
print(f"gym id '{self.gym_id}' not found, attempting to import '{self.package_name}'...")
try:
importlib.import_module(self.package_name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Package '{self.package_name}' required for env '{self.type}' not found. "
f"Please install it or check PYTHONPATH."
) from e
if self.gym_id not in gym_registry:
raise gym.error.NameNotFound(
f"Environment '{self.gym_id}' not registered even after importing '{self.package_name}'."
)
def _make_one():
return gym.make(self.gym_id, disable_env_checker=self.disable_env_checker, **self.gym_kwargs)
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
return {self.type: {0: vec}}
def get_env_processors(self):
"""Return (preprocessor, postprocessor) for this env. Default: identity."""
from lerobot.processor.pipeline import PolicyProcessorPipeline
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
@dataclass
class HubEnvConfig(EnvConfig):
@@ -345,6 +389,32 @@ class LiberoEnv(EnvConfig):
kwargs["task_ids"] = self.task_ids
return kwargs
def create_envs(self, n_envs: int, use_async_envs: bool = True):
from lerobot.envs.libero import create_libero_envs
if self.task is None:
raise ValueError("LiberoEnv requires a task to be specified")
env_cls = gym.vector.AsyncVectorEnv if (use_async_envs and n_envs > 1) else gym.vector.SyncVectorEnv
return create_libero_envs(
task=self.task,
n_envs=n_envs,
camera_name=self.camera_name,
init_states=self.init_states,
gym_kwargs=self.gym_kwargs,
env_cls=env_cls,
control_mode=self.control_mode,
episode_length=self.episode_length,
)
def get_env_processors(self):
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
return (
PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
PolicyProcessorPipeline(steps=[]),
)
@EnvConfig.register_subclass("metaworld")
@dataclass
@@ -387,6 +457,19 @@ class MetaworldEnv(EnvConfig):
"render_mode": self.render_mode,
}
def create_envs(self, n_envs: int, use_async_envs: bool = True):
from lerobot.envs.metaworld import create_metaworld_envs
if self.task is None:
raise ValueError("MetaWorld requires a task to be specified")
env_cls = gym.vector.AsyncVectorEnv if (use_async_envs and n_envs > 1) else gym.vector.SyncVectorEnv
return create_metaworld_envs(
task=self.task,
n_envs=n_envs,
gym_kwargs=self.gym_kwargs,
env_cls=env_cls,
)
@EnvConfig.register_subclass("isaaclab_arena")
@dataclass
@@ -454,3 +537,18 @@ class IsaaclabArenaEnv(HubEnvConfig):
@property
def gym_kwargs(self) -> dict:
return {}
def get_env_processors(self):
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
if not state_keys and not camera_keys:
raise ValueError("At least one of state_keys or camera_keys must be specified.")
return (
PolicyProcessorPipeline(
steps=[IsaaclabArenaProcessorStep(state_keys=state_keys, camera_keys=camera_keys)]
),
PolicyProcessorPipeline(steps=[]),
)

View File

@@ -13,96 +13,52 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from __future__ import annotations
from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
from lerobot.envs.configs import EnvConfig, HubEnvConfig
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
if env_type == "aloha":
return AlohaEnv(**kwargs)
elif env_type == "pusht":
return PushtEnv(**kwargs)
elif env_type == "libero":
return LiberoEnv(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
try:
cls = EnvConfig.get_choice_class(env_type)
except KeyError as err:
raise ValueError(
f"Environment type '{env_type}' is not registered. "
f"Available: {list(EnvConfig.get_known_choices().keys())}"
) from err
return cls(**kwargs)
def make_env_pre_post_processors(
env_cfg: EnvConfig,
policy_cfg: PreTrainedConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
policy_cfg: Any,
) -> tuple[Any, Any]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
This function creates processor pipelines that transform raw environment
observations and actions. By default, it returns identity processors that do nothing.
For specific environments like LIBERO, it adds environment-specific processing steps.
Args:
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs (currently identity)
Returns a tuple of (preprocessor, postprocessor). By default, delegates to
``env_cfg.get_env_processors()``. The XVLAConfig policy-specific override
stays here because it depends on the *policy* config, not the env config.
"""
# Preprocessor and Postprocessor steps are Identity for most environments
preprocessor_steps: list[ProcessorStep] = []
postprocessor_steps: list[ProcessorStep] = []
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
if isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
return make_xvla_libero_pre_post_processors()
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep())
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
# Parse comma-separated keys (handle None for state-based policies)
if env_cfg.state_keys:
state_keys = tuple(k.strip() for k in env_cfg.state_keys.split(",") if k.strip())
else:
state_keys = ()
if env_cfg.camera_keys:
camera_keys = tuple(k.strip() for k in env_cfg.camera_keys.split(",") if k.strip())
else:
camera_keys = ()
if not state_keys and not camera_keys:
raise ValueError("At least one of state_keys or camera_keys must be specified.")
preprocessor_steps.append(
IsaaclabArenaProcessorStep(
state_keys=state_keys,
camera_keys=camera_keys,
)
)
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
return preprocessor, postprocessor
return env_cfg.get_env_processors()
def make_env(
cfg: EnvConfig | str,
n_envs: int = 1,
use_async_envs: bool = False,
use_async_envs: bool = True,
hub_cache_dir: str | None = None,
trust_remote_code: bool = False,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
@@ -163,57 +119,4 @@ def make_env(
if n_envs < 1:
raise ValueError("`n_envs` must be at least 1")
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
if "libero" in cfg.type:
from lerobot.envs.libero import create_libero_envs
if cfg.task is None:
raise ValueError("LiberoEnv requires a task to be specified")
return create_libero_envs(
task=cfg.task,
n_envs=n_envs,
camera_name=cfg.camera_name,
init_states=cfg.init_states,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
control_mode=cfg.control_mode,
episode_length=cfg.episode_length,
)
elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs
if cfg.task is None:
raise ValueError("MetaWorld requires a task to be specified")
return create_metaworld_envs(
task=cfg.task,
n_envs=n_envs,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
)
if cfg.gym_id not in gym_registry:
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
try:
importlib.import_module(cfg.package_name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
f"Please install it or check PYTHONPATH."
) from e
if cfg.gym_id not in gym_registry:
raise gym.error.NameNotFound(
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
)
def _make_one():
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
# normalize to {suite: {task_id: vec_env}} for consistency
suite_name = cfg.type # e.g., "pusht", "aloha"
return {suite_name: {0: vec}}
return cfg.create_envs(n_envs=n_envs, use_async_envs=use_async_envs)

View File

@@ -150,7 +150,17 @@ class LiberoEnv(gym.Env):
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id)
# Extract task metadata without allocating GPU resources (safe before fork).
task = task_suite.get_task(task_id)
self.task = task.name
self.task_description = task.language
self._task_bddl_file = os.path.join(
get_libero_path("bddl_files"), task.problem_folder, task.bddl_file
)
self._env: OffScreenRenderEnv | None = (
None # deferred — created on first reset() inside the worker subprocess
)
default_steps = 500
self._max_episode_steps = (
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
@@ -221,28 +231,32 @@ class LiberoEnv(gym.Env):
low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
)
def _ensure_env(self) -> None:
"""Create the underlying OffScreenRenderEnv on first use.
Called inside the worker subprocess after fork(), so each worker gets
its own clean EGL context rather than inheriting a stale one from the
parent process (which causes EGL_BAD_CONTEXT crashes with AsyncVectorEnv).
"""
if self._env is not None:
return
env = OffScreenRenderEnv(
bddl_file_name=self._task_bddl_file,
camera_heights=self.observation_height,
camera_widths=self.observation_width,
)
env.reset()
self._env = env
def render(self):
self._ensure_env()
raw_obs = self._env.env._get_observations()
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
image = image[::-1, ::-1] # flip both H and W for visualization
return image
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
task = task_suite.get_task(task_id)
self.task = task.name
self.task_description = task.language
task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
env_args = {
"bddl_file_name": task_bddl_file,
"camera_heights": self.observation_height,
"camera_widths": self.observation_width,
}
env = OffScreenRenderEnv(**env_args)
env.reset()
return env
def _format_raw_obs(self, raw_obs: RobotObservation) -> RobotObservation:
assert self._env is not None, "_format_raw_obs called before _ensure_env()"
images = {}
for camera_name in self.camera_name:
image = raw_obs[camera_name]
@@ -294,6 +308,7 @@ class LiberoEnv(gym.Env):
)
def reset(self, seed=None, **kwargs):
self._ensure_env()
super().reset(seed=seed)
self._env.seed(seed)
raw_obs = self._env.reset()
@@ -320,6 +335,8 @@ class LiberoEnv(gym.Env):
return observation, info
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
self._ensure_env()
assert self._env is not None
if action.ndim != 1:
raise ValueError(
f"Expected action to be 1-D (shape (action_dim,)), "
@@ -350,7 +367,8 @@ class LiberoEnv(gym.Env):
return observation, reward, terminated, truncated, info
def close(self):
self._env.close()
if self._env is not None:
self._env.close()
def _make_env_fns(

View File

@@ -97,8 +97,9 @@ class MetaworldEnv(gym.Env):
self.visualization_height = visualization_height
self.camera_name = camera_name
self._env = self._make_envs_task(self.task)
self._max_episode_steps = self._env.max_path_length
self._env_name = self.task # already stripped of "metaworld-" prefix above
self._env = None # deferred — created on first reset() inside the worker subprocess
self._max_episode_steps = 500 # MT1 environments always have max_path_length=500
self.task_description = TASK_DESCRIPTIONS[self.task]
self.expert_policy = TASK_POLICY_MAPPING[self.task]()
@@ -136,6 +137,24 @@ class MetaworldEnv(gym.Env):
self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32)
def _ensure_env(self) -> None:
"""Create the underlying MetaWorld env on first use.
Called inside the worker subprocess after fork(), so each worker gets
its own clean rendering context rather than inheriting a stale one from
the parent process (which causes crashes with AsyncVectorEnv).
"""
if self._env is not None:
return
mt1 = metaworld.MT1(self._env_name, seed=42)
env = mt1.train_classes[self._env_name](render_mode="rgb_array", camera_name=self.camera_name)
env.set_task(mt1.train_tasks[0])
if self.camera_name == "corner2":
env.model.cam_pos[2] = [0.75, 0.075, 0.7]
env.reset()
env._freeze_rand_vec = False # otherwise no randomization
self._env = env
def render(self) -> np.ndarray:
"""
Render the current environment frame.
@@ -143,26 +162,13 @@ class MetaworldEnv(gym.Env):
Returns:
np.ndarray: The rendered RGB image from the environment.
"""
self._ensure_env()
image = self._env.render()
if self.camera_name == "corner2":
# Images from this camera are flipped — correct them
image = np.flip(image, (0, 1))
return image
def _make_envs_task(self, env_name: str):
mt1 = metaworld.MT1(env_name, seed=42)
env = mt1.train_classes[env_name](render_mode="rgb_array", camera_name=self.camera_name)
env.set_task(mt1.train_tasks[0])
if self.camera_name == "corner2":
env.model.cam_pos[2] = [
0.75,
0.075,
0.7,
] # corner2 position, similar to https://arxiv.org/pdf/2206.14244
env.reset()
env._freeze_rand_vec = False # otherwise no randomization
return env
def _format_raw_obs(self, raw_obs: np.ndarray) -> RobotObservation:
image = None
if self._env is not None:
@@ -209,6 +215,7 @@ class MetaworldEnv(gym.Env):
observation (RobotObservation): The initial formatted observation.
info (Dict[str, Any]): Additional info about the reset state.
"""
self._ensure_env()
super().reset(seed=seed)
raw_obs, info = self._env.reset(seed=seed)
@@ -232,6 +239,7 @@ class MetaworldEnv(gym.Env):
truncated (bool): Whether the episode was truncated due to a time limit.
info (Dict[str, Any]): Additional environment info.
"""
self._ensure_env()
if action.ndim != 1:
raise ValueError(
f"Expected action to be 1-D (shape (action_dim,)), "
@@ -263,7 +271,8 @@ class MetaworldEnv(gym.Env):
return observation, reward, terminated, truncated, info
def close(self):
self._env.close()
if self._env is not None:
self._env.close()
# ---- Main API ----------------------------------------------------------------

View File

@@ -15,6 +15,7 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
@@ -28,6 +29,7 @@ from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
__all__ = [
"ACTConfig",
"DiffusionConfig",
"MultiTaskDiTConfig",
"PI0Config",
"PI05Config",
"PI0FastConfig",

View File

@@ -31,6 +31,7 @@ from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -58,6 +59,29 @@ from lerobot.utils.constants import (
)
def _reconnect_relative_absolute_steps(
preprocessor: PolicyProcessorPipeline, postprocessor: PolicyProcessorPipeline
) -> None:
"""Wire AbsoluteActionsProcessorStep.relative_step to the RelativeActionsProcessorStep after deserialization.
After a policy is loaded from disk, the preprocessor and postprocessor are reconstructed
independently from their configs. AbsoluteActionsProcessorStep needs a live reference to
the RelativeActionsProcessorStep so it can read the cached state at inference time.
That reference is not serializable, so we re-establish it here after loading.
"""
from lerobot.processor.relative_action_processor import (
AbsoluteActionsProcessorStep,
RelativeActionsProcessorStep,
)
relative_step = next((s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep)), None)
if relative_step is None:
return
for step in postprocessor.steps:
if isinstance(step, AbsoluteActionsProcessorStep) and step.relative_step is None:
step.relative_step = relative_step
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
"""
Retrieves a policy class by its registered name.
@@ -67,8 +91,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -87,6 +110,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.act.modeling_act import ACTPolicy
return ACTPolicy
elif name == "multi_task_dit":
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
return MultiTaskDiTPolicy
elif name == "vqbet":
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
@@ -147,8 +174,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier", "wall_x".
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"smolvla", "reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -163,6 +190,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return DiffusionConfig(**kwargs)
elif policy_type == "act":
return ACTConfig(**kwargs)
elif policy_type == "multi_task_dit":
return MultiTaskDiTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
@@ -263,26 +292,26 @@ def make_pre_post_processors(
kwargs["preprocessor_overrides"] = preprocessor_overrides
kwargs["postprocessor_overrides"] = postprocessor_overrides
return (
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get(
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("preprocessor_overrides", {}),
to_transition=batch_to_transition,
to_output=transition_to_batch,
),
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get(
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("postprocessor_overrides", {}),
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
preprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get(
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("preprocessor_overrides", {}),
to_transition=batch_to_transition,
to_output=transition_to_batch,
)
postprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get(
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("postprocessor_overrides", {}),
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
)
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
# Create a new processor based on policy type
if isinstance(policy_cfg, TDMPCConfig):
@@ -309,6 +338,16 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, MultiTaskDiTConfig):
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
make_multi_task_dit_pre_post_processors,
)
processors = make_multi_task_dit_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, VQBeTConfig):
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
@@ -470,6 +509,13 @@ def make_policy(
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
# Store action feature names for relative_exclude_joints support
if ds_meta is not None and hasattr(cfg, "action_feature_names"):
action_names = ds_meta.features.get(ACTION, {}).get("names")
if action_names is not None:
cfg.action_feature_names = list(action_names)
kwargs["config"] = cfg
# Pass dataset_stats to the policy if available (needed for some policies like SARM)

View File

@@ -0,0 +1,37 @@
# Multitask DiT Policy
## Citation
If you use this work, please cite the following works:
```bibtex
@misc{jones2025multitaskditpolicy,
author = {Bryson Jones},
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
year = {2025},
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
note = {Blog post}
}
```
```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
author = {TRI LBM Team},
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
year = {2025},
eprint = {arXiv:2507.05331},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2507.05331}
}
```
```bibtex
@misc{bostondynamics2025largebehaviormodelsatlas,
author = {Boston Dynamics and TRI Research Team},
title = {Large Behavior Models and Atlas Find New Footing},
year = {2025},
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
note = {Blog post}
}
```

View File

@@ -0,0 +1,21 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_multi_task_dit import MultiTaskDiTConfig
from .modeling_multi_task_dit import MultiTaskDiTPolicy
from .processor_multi_task_dit import make_multi_task_dit_pre_post_processors
__all__ = ["MultiTaskDiTConfig", "MultiTaskDiTPolicy", "make_multi_task_dit_pre_post_processors"]

View File

@@ -0,0 +1,256 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig
@PreTrainedConfig.register_subclass("multi_task_dit")
@dataclass
class MultiTaskDiTConfig(PreTrainedConfig):
"""Configuration for the Multi-Task Diffusion Transformer (DiT) policy.
A transformer-based policy that supports both diffusion and flow matching objectives
for multi-task robot learning with text and vision conditioning.
"""
n_obs_steps: int = 2 # Number of observation steps for temporal context
horizon: int = 32 # Number of action steps to predict
n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz)
# Objective Selection
objective: str = "diffusion" # "diffusion" or "flow_matching"
# --- Diffusion-specific (used when objective="diffusion") ---
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
num_train_timesteps: int = 100 # Number of diffusion timesteps
beta_schedule: str = "squaredcos_cap_v2" # Noise schedule type
beta_start: float = 0.0001 # Starting noise level
beta_end: float = 0.02 # Ending noise level
prediction_type: str = "epsilon" # "epsilon" (predict noise) or "sample" (predict clean)
clip_sample: bool = True # Clip samples during denoising
clip_sample_range: float = 1.0 # Clipping range [-x, x]
num_inference_steps: int | None = None # Denoising steps at inference (defaults to num_train_timesteps)
# --- Flow Matching-specific (used when objective="flow_matching") ---
sigma_min: float = 0.0 # Minimum noise in flow interpolation path
num_integration_steps: int = 100 # ODE integration steps at inference
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
timestep_sampling_strategy: str = "beta" # "uniform" or "beta"
timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold
timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha
timestep_sampling_beta: float = 1.0 # (beta only) Beta distribution beta
# Transformer Architecture
hidden_dim: int = 512 # Transformer hidden dimension
num_layers: int = 6 # Number of transformer layers
num_heads: int = 8 # Number of attention heads
dropout: float = 0.1 # Dropout rate
use_positional_encoding: bool = False # Use absolute positional encoding
timestep_embed_dim: int = 256 # Timestep embedding dimension
use_rope: bool = True # Use Rotary Position Embedding
rope_base: float = 10000.0 # RoPE base frequency
# Vision Encoder (CLIP)
vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
use_separate_rgb_encoder_per_camera: bool = False # Separate encoder per camera view
vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder
image_resize_shape: tuple[int, int] | None = None # Resize images before crop
image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default)
image_crop_is_random: bool = True # Random crop during training, center at inference
# Text Encoder (CLIP)
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
tokenizer_max_length: int = 77 # Max length for tokenized text (CLIP default is 77)
tokenizer_padding: str = "max_length" # Padding strategy: "max_length" or "longest"
tokenizer_padding_side: str = "right" # Padding side: "left" or "right"
tokenizer_truncation: bool = True # Whether to truncate sequences longer than max_length
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# Training/Optimizer
optimizer_lr: float = 2e-5
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.0
scheduler_name: str = "cosine"
scheduler_warmup_steps: int = 0
do_mask_loss_for_padding: bool = False
# Auto-calculated
drop_n_last_frames: int | None = None
def __post_init__(self):
super().__post_init__()
if self.drop_n_last_frames is None:
self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1
self._validate()
def _validate(self):
"""Validate configuration parameters."""
# Objective validation
if self.objective not in ["diffusion", "flow_matching"]:
raise ValueError(f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'")
# Transformer validation
if self.hidden_dim <= 0:
raise ValueError("hidden_dim must be positive")
if self.num_layers <= 0:
raise ValueError("num_layers must be positive")
if self.num_heads <= 0:
raise ValueError("num_heads must be positive")
if self.hidden_dim % self.num_heads != 0:
raise ValueError("hidden_dim must be divisible by num_heads")
if not (0.0 <= self.dropout <= 1.0):
raise ValueError("dropout must be between 0.0 and 1.0")
# Vision encoder validation
if "clip" not in self.vision_encoder_name.lower():
raise ValueError(
f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'"
)
if (
self.image_resize_shape
and self.image_crop_shape
and (
self.image_crop_shape[0] > self.image_resize_shape[0]
or self.image_crop_shape[1] > self.image_resize_shape[1]
)
):
logging.warning(
"image_crop_shape %s must be <= image_resize_shape %s; disabling cropping.",
self.image_crop_shape,
self.image_resize_shape,
)
self.image_crop_shape = None
# Text encoder validation
if "clip" not in self.text_encoder_name.lower():
raise ValueError(
f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'"
)
# Objective-specific validation
if self.objective == "diffusion":
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
raise ValueError(
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
)
if self.prediction_type not in ["epsilon", "sample"]:
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
if self.num_train_timesteps <= 0:
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}")
elif self.objective == "flow_matching":
if not (0.0 <= self.sigma_min <= 1.0):
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
if self.num_integration_steps <= 0:
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
if self.integration_method not in ["euler", "rk4"]:
raise ValueError(
f"integration_method must be 'euler' or 'rk4', got {self.integration_method}"
)
if self.timestep_sampling_strategy not in ["uniform", "beta"]:
raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'")
if self.timestep_sampling_strategy == "beta":
if not (0.0 < self.timestep_sampling_s <= 1.0):
raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}")
if self.timestep_sampling_alpha <= 0:
raise ValueError("timestep_sampling_alpha must be positive")
if self.timestep_sampling_beta <= 0:
raise ValueError("timestep_sampling_beta must be positive")
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
return DiffuserSchedulerConfig(
name=self.scheduler_name,
num_warmup_steps=self.scheduler_warmup_steps,
)
def validate_features(self) -> None:
"""Validate that required input features are present and properly configured."""
# If the configured crop doesn't fit, disable cropping instead of erroring.
# Note: if image_resize_shape is set, cropping is applied *after* resizing.
if self.image_crop_shape is not None:
for key, image_ft in self.image_features.items():
# image_ft.shape is (C, H, W)
effective_h, effective_w = (
self.image_resize_shape
if self.image_resize_shape is not None
else (image_ft.shape[1], image_ft.shape[2])
)
if self.image_crop_shape[0] > effective_h or self.image_crop_shape[1] > effective_w:
logging.warning(
"image_crop_shape %s doesn't fit within effective image shape (%s, %s) for '%s'; disabling cropping.",
self.image_crop_shape,
effective_h,
effective_w,
key,
)
self.image_crop_shape = None
break
if len(self.image_features) > 0:
first_key, first_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_ft.shape:
raise ValueError(
f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}"
)
@property
def is_diffusion(self) -> bool:
return self.objective == "diffusion"
@property
def is_flow_matching(self) -> bool:
return self.objective == "flow_matching"
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,803 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Task Diffusion Transformer (DiT) Policy
Transformer-based diffusion policy for multi-task robot learning with text and vision conditioning.
Supports both diffusion and flow matching objectives for action generation.
References:
- https://arxiv.org/abs/2507.05331
- https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/
- https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy
"""
import math
from collections import deque
from typing import TYPE_CHECKING
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import CLIPTextModel, CLIPVisionModel
else:
CLIPTextModel = None
CLIPVisionModel = None
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
# -- Policy --
class MultiTaskDiTPolicy(PreTrainedPolicy):
config_class = MultiTaskDiTConfig
name = "multi_task_dit"
def __init__(self, config: MultiTaskDiTConfig, **kwargs):
super().__init__(config)
config.validate_features()
self.config = config
self._queues = None
self.observation_encoder = ObservationEncoder(config)
conditioning_dim = self.observation_encoder.conditioning_dim
self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim)
action_dim = config.action_feature.shape[0]
horizon = config.horizon
if config.is_diffusion:
self.objective = DiffusionObjective(
config,
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
elif config.is_flow_matching:
self.objective = FlowMatchingObjective(
config,
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
else:
raise ValueError(f"Unsupported objective: {config.objective}")
self.reset()
def get_optim_params(self) -> list:
"""Returns parameter groups with different learning rates for vision vs non-vision parameters"""
non_vision_params = []
vision_encoder_params = []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
if "observation_encoder.vision_encoder" in name:
vision_encoder_params.append(param)
else:
non_vision_params.append(param)
return [
{"params": non_vision_params},
{
"params": vision_encoder_params,
"lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier,
},
]
def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
conditioning_vec = self.observation_encoder.encode(batch)
actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec)
start = n_obs_steps - 1
end = start + self.config.n_action_steps
actions = actions[:, start:end]
return actions
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
ACTION: deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations"""
self.eval()
for k in batch:
if k in self._queues:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
actions = self._generate_actions(batch)
return actions
def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Prepare batch by stacking image features if needed."""
if self.config.image_features:
batch = dict(batch) # shallow copy to avoid modifying original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
return batch
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations"""
if ACTION in batch:
batch = dict(batch) # shallow copy to avoid modifying original
batch.pop(ACTION)
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1))
action = self._queues[ACTION].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Run the batch through the model and compute the loss for training"""
batch = self._prepare_batch(batch)
conditioning_vec = self.observation_encoder.encode(batch)
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
return loss, None
# -- Observation Encoders --
class CLIPVisionEncoder(nn.Module):
"""CLIP vision encoder using the CLS token for global image representation."""
def __init__(self, model_name: str):
super().__init__()
self.model_name = model_name
self.model = CLIPVisionModel.from_pretrained(self.model_name)
self.num_non_spatial_tokens = 1
self.embed_dim = self.model.config.hidden_size
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to CLS token."""
outputs = self.model(pixel_values=x, output_hidden_states=False)
cls_token = outputs.last_hidden_state[:, 0]
b, embed_dim = cls_token.shape
return cls_token.reshape(b, embed_dim, 1, 1)
def get_output_shape(self) -> tuple:
return (self.embed_dim, 1, 1)
class CLIPTextEncoder(nn.Module):
"""CLIP text encoder with frozen weights and a learnable projection layer.
Accepts pre-tokenized inputs (input_ids and attention_mask) from the processor pipeline. See the processor
pipeline to see how the tokenization is handled.
"""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__()
self.model_name = model_name
self.projection_dim = projection_dim
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
for param in self.text_encoder.parameters():
param.requires_grad = False
self.text_embed_dim = self.text_encoder.config.hidden_size
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
"""Encode pre-tokenized text to feature vectors."""
# Ensure inputs are on the same device as the model
device = next(self.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
with torch.no_grad():
outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
clip_features = outputs.pooler_output
return self.projection(clip_features)
class ObservationEncoder(nn.Module):
"""Handles all observation processing for the conditioning vector."""
def __init__(self, config):
super().__init__()
self.config = config
self._setup_preprocessing(config)
if config.image_features:
self.num_cameras = len(config.image_features)
self.camera_names = list(config.image_features.keys())
if config.use_separate_rgb_encoder_per_camera:
self.vision_encoders = nn.ModuleList(
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
)
self.vision_encoder = None
else:
self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
self.vision_encoders = None
else:
self.vision_encoder = None
self.vision_encoders = None
self.camera_names = []
self.num_cameras = 0
if hasattr(config, "robot_state_feature") and config.robot_state_feature:
self.robot_state_dim = config.robot_state_feature.shape[0]
else:
self.robot_state_dim = 0
self.text_dim = config.hidden_dim
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
self._setup_vector_output()
def _apply_preprocessing(self, images: Tensor) -> Tensor:
if self.do_resize:
images = self.resize(images)
if self.do_crop:
images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
return images
def _setup_preprocessing(self, config):
if config.image_resize_shape is not None:
self.do_resize = True
self.resize = torchvision.transforms.Resize(
size=config.image_resize_shape,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True,
)
else:
self.do_resize = False
if config.image_crop_shape is not None:
self.do_crop = True
self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
if config.image_crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
def _setup_vector_output(self):
total_dim = 0
if self.vision_encoder is not None or self.vision_encoders is not None:
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
feature_map_shape = encoder_to_check.get_output_shape()
c, h, w = feature_map_shape
spatial_feature_dim = c * h * w
total_dim += spatial_feature_dim * self.num_cameras
total_dim += self.robot_state_dim
total_dim += self.text_dim
self.conditioning_dim = total_dim * self.config.n_obs_steps
def encode(self, batch: dict) -> Tensor:
"""Encode observations to vector format."""
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
conditioning_feats = []
conditioning_feats.append(batch[OBS_STATE])
if self.vision_encoder is not None or self.vision_encoders is not None:
images = batch[OBS_IMAGES]
if len(images.shape) == 5:
images = images.unsqueeze(1)
if self.config.use_separate_rgb_encoder_per_camera:
camera_features = []
for cam_idx in range(self.num_cameras):
cam_images = images[:, :, cam_idx]
cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w")
cam_images_flat = self._apply_preprocessing(cam_images_flat)
cam_features = self.vision_encoders[cam_idx](cam_images_flat)
cam_visual_features = cam_features.flatten(start_dim=1)
cam_features_reshaped = einops.rearrange(
cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps
)
camera_features.append(cam_features_reshaped)
img_features = torch.cat(camera_features, dim=-1)
conditioning_feats.append(img_features)
else:
images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
images_flat = self._apply_preprocessing(images_flat)
visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
img_features = einops.rearrange(
visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
)
conditioning_feats.append(img_features)
if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch:
input_ids = batch[OBS_LANGUAGE_TOKENS] # [batch_size, seq_length]
attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK] # [batch_size, seq_length]
text_features = self.text_encoder(input_ids, attention_mask)
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
conditioning_feats.append(text_features)
combined_features = torch.cat(conditioning_feats, dim=-1)
return combined_features.flatten(start_dim=1)
# -- Transformer Components --
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
"""Modulate input with shift and scale for AdaLN-Zero."""
return x * (1 + scale) + shift
class SinusoidalPosEmb(nn.Module):
"""Sinusoidal positional embeddings for timesteps."""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class RotaryPositionalEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE) for transformers."""
def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
super().__init__()
assert head_dim % 2 == 0, "head_dim must be even for RoPE"
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.base = base
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._precompute_cache(max_seq_len)
def _precompute_cache(self, seq_len: int):
t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)
def _rotate_half(self, x: Tensor) -> Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
seq_len = q.shape[2]
if seq_len > self.max_seq_len:
raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}.")
cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype)
sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype)
q_rotated = (q * cos) + (self._rotate_half(q) * sin)
k_rotated = (k * cos) + (self._rotate_half(k) * sin)
return q_rotated, k_rotated
class RoPEAttention(nn.Module):
"""Multi-head self-attention with Rotary Position Embedding (RoPE)."""
def __init__(
self,
hidden_size: int,
num_heads: int,
dropout: float = 0.0,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
super().__init__()
assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5
self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)
def forward(self, x: Tensor) -> Tensor:
B, T, _ = x.shape # noqa: N806
qkv = self.qkv_proj(x)
qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q, k = self.rope(q, k)
attn_out = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0,
)
attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size)
return self.out_proj(attn_out)
class TransformerBlock(nn.Module):
"""DiT-style transformer block with AdaLN-Zero."""
def __init__(
self,
hidden_size: int = 128,
num_heads: int = 4,
num_features: int = 128,
dropout: float = 0.0,
use_rope: bool = False,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
super().__init__()
self.use_rope = use_rope
if use_rope:
self.attn = RoPEAttention(
hidden_size=hidden_size,
num_heads=num_heads,
dropout=dropout,
max_seq_len=max_seq_len,
rope_base=rope_base,
)
else:
self.multihead_attn = nn.MultiheadAttention(
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
)
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.GELU(approximate="tanh"),
nn.Linear(hidden_size * 4, hidden_size),
)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True))
def forward(self, x: Tensor, features: Tensor) -> Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
features
).chunk(6, dim=1)
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
if self.use_rope:
attn_out = self.attn(attn_input)
else:
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
x = x + gate_msa.unsqueeze(1) * attn_out
mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1))
mlp_out = self.mlp(mlp_input)
x = x + gate_mlp.unsqueeze(1) * mlp_out
return x
class DiffusionTransformer(nn.Module):
"""Transformer-based diffusion noise prediction model."""
def __init__(self, config, conditioning_dim: int):
super().__init__()
self.config = config
self.conditioning_dim = conditioning_dim
self.action_dim = config.action_feature.shape[0]
self.horizon = config.horizon
self.hidden_size = config.hidden_dim
self.num_layers = config.num_layers
self.num_heads = config.num_heads
self.dropout = config.dropout
self.use_rope = config.use_rope
self.timestep_embed_dim = config.timestep_embed_dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(self.timestep_embed_dim),
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
nn.GELU(),
nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim),
nn.GELU(),
)
self.cond_dim = self.timestep_embed_dim + conditioning_dim
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
if config.use_positional_encoding:
self.pos_embedding = nn.Parameter(
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
)
else:
self.pos_embedding = None
self.transformer_blocks = nn.ModuleList(
[
TransformerBlock(
hidden_size=self.hidden_size,
num_heads=self.num_heads,
num_features=self.cond_dim,
dropout=self.dropout,
use_rope=self.use_rope,
max_seq_len=self.horizon,
rope_base=config.rope_base,
)
for _ in range(self.num_layers)
]
)
self.output_proj = nn.Linear(self.hidden_size, self.action_dim)
self._initialize_weights()
def _initialize_weights(self):
for block in self.transformer_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor:
_, seq_len, _ = x.shape
timestep_features = self.time_mlp(timestep)
cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1)
hidden_seq = self.input_proj(x)
if self.pos_embedding is not None:
hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :]
for block in self.transformer_blocks:
hidden_seq = block(hidden_seq, cond_features)
return self.output_proj(hidden_seq)
# -- Objectives --
class DiffusionObjective(nn.Module):
"""Standard diffusion (DDPM/DDIM) objective implementation."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__()
self.config = config
self.action_dim = action_dim
self.horizon = horizon
self.do_mask_loss_for_padding = do_mask_loss_for_padding
scheduler_kwargs = {
"num_train_timesteps": config.num_train_timesteps,
"beta_start": config.beta_start,
"beta_end": config.beta_end,
"beta_schedule": config.beta_schedule,
"clip_sample": config.clip_sample,
"clip_sample_range": config.clip_sample_range,
"prediction_type": config.prediction_type,
}
if config.noise_scheduler_type == "DDPM":
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
elif config.noise_scheduler_type == "DDIM":
self.noise_scheduler = DDIMScheduler(**scheduler_kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}")
self.num_inference_steps = (
config.num_inference_steps
if config.num_inference_steps is not None
else self.noise_scheduler.config.num_train_timesteps
)
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
clean_actions = batch[ACTION]
noise = torch.randn_like(clean_actions)
timesteps = torch.randint(
low=0,
high=self.noise_scheduler.config.num_train_timesteps,
size=(clean_actions.shape[0],),
device=clean_actions.device,
).long()
noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps)
prediction_type = self.noise_scheduler.config.prediction_type
if prediction_type == "epsilon":
target = noise
elif prediction_type == "sample":
target = clean_actions
else:
raise ValueError(f"Unsupported prediction type: {prediction_type}")
predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted, target, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_actions = ~batch["action_is_pad"]
loss = loss * valid_actions.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
sample = torch.randn(
size=(batch_size, self.horizon, self.action_dim),
dtype=dtype,
device=device,
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
for t in self.noise_scheduler.timesteps:
model_output = model(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
conditioning_vec=conditioning_vec,
)
sample = self.noise_scheduler.step(model_output, t, sample).prev_sample
return sample
class FlowMatchingObjective(nn.Module):
"""Flow matching objective: trains a model to predict velocity fields."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__()
self.config = config
self.action_dim = action_dim
self.horizon = horizon
self.do_mask_loss_for_padding = do_mask_loss_for_padding
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
if self.config.timestep_sampling_strategy == "uniform":
return torch.rand(batch_size, device=device)
elif self.config.timestep_sampling_strategy == "beta":
beta_dist = torch.distributions.Beta(
self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
)
u = beta_dist.sample((batch_size,)).to(device)
return self.config.timestep_sampling_s * (1.0 - u)
else:
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
data = batch[ACTION]
batch_size = data.shape[0]
device = data.device
noise = torch.randn_like(data)
t = self._sample_timesteps(batch_size, device)
t_expanded = t.view(-1, 1, 1)
x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise
target_velocity = data - (1 - self.config.sigma_min) * noise
predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_mask = ~batch["action_is_pad"]
loss = loss * valid_mask.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device)
num_steps = self.config.num_integration_steps
time_grid = torch.linspace(0, 1, num_steps + 1, device=device)
if self.config.integration_method == "euler":
x = self._euler_integrate(model, x, time_grid, conditioning_vec)
elif self.config.integration_method == "rk4":
x = self._rk4_integrate(model, x, time_grid, conditioning_vec)
else:
raise ValueError(f"Unknown integration method: {self.config.integration_method}")
return x
def _euler_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
x = x_init
for i in range(len(time_grid) - 1):
t_scalar = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device)
with torch.no_grad():
velocity = model(x, t_batch, conditioning_vec=conditioning_vec)
x = x + dt * velocity
return x
def _rk4_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
x = x_init
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
with torch.no_grad():
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
for i in range(len(time_grid) - 1):
t = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
k1 = dynamics(x, t)
k2 = dynamics(x + dt * k1 / 2, t + dt / 2)
k3 = dynamics(x + dt * k2 / 2, t + dt / 2)
k4 = dynamics(x + dt * k3, t + dt)
x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
return x

View File

@@ -0,0 +1,105 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_multi_task_dit_pre_post_processors(
config: MultiTaskDiTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for a Multi-Task DiT policy.
The pre-processing pipeline prepares the input data for the model by:
1. Renaming features.
2. Adding a batch dimension.
3. Tokenizing the language task description (if present).
4. Moving the data to the specified device.
5. Normalizing the input and output features based on dataset statistics.
The post-processing pipeline handles the model's output by:
1. Unnormalizing the output features to their original scale.
2. Moving the data to the CPU.
Args:
config: The configuration object for the Multi-Task DiT policy,
containing feature definitions, normalization mappings, and device information.
dataset_stats: A dictionary of statistics used for normalization.
Defaults to None.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
TokenizerProcessorStep(
tokenizer_name=config.text_encoder_name,
padding=config.tokenizer_padding,
padding_side=config.tokenizer_padding_side,
max_length=config.tokenizer_max_length,
truncation=config.tokenizer_truncation,
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
device=config.device,
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)

View File

@@ -17,6 +17,65 @@ It is designed as a **Vision-Language-Action model for general robot control**.
---
## Relative Actions
π₀ supports training with **relative actions**, where the model learns relative offsets
from the current robot state instead of absolute joint positions. This mirrors the
relative-action transform in OpenPI (`DeltaActions`) and can improve performance.
### How it works
1. **During preprocessing**, absolute actions are converted to relative offsets:
`relative = action - state` (for selected joints).
2. The relative actions are normalized using statistics computed from the relative distribution.
3. **During postprocessing**, predicted relative actions are converted back to absolute:
`absolute = relative + state`.
Joints listed in `relative_exclude_joints` (e.g., gripper) are kept absolute.
### Configuration
| Parameter | Type | Default | Description |
| ------------------------- | ----------- | ------------- | ---------------------------------------------------------------- |
| `use_relative_actions` | `bool` | `False` | Enable relative-action training |
| `relative_exclude_joints` | `list[str]` | `["gripper"]` | Joint names to keep absolute (matched by substring) |
| `action_feature_names` | `list[str]` | `None` | Auto-populated from dataset metadata at runtime by `make_policy` |
### Training example
```bash
python -m lerobot.scripts.lerobot_train \
--policy.type=pi0 \
--dataset.repo_id=your_org/your_dataset \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]'
```
When `use_relative_actions=true`, the training script automatically:
- Computes relative action statistics from the dataset (sampled chunk-level relative actions)
- Replaces the standard action stats with relative stats for normalization
- Broadcasts these stats across all ranks in distributed training
### Recomputing stats for an existing dataset
If you want to precompute relative action stats offline, use `recompute_stats` from
`lerobot.datasets.dataset_tools`:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.dataset_tools import recompute_stats
dataset = LeRobotDataset("your_org/your_dataset")
dataset = recompute_stats(
dataset,
relative_action=True,
relative_exclude_joints=["gripper"],
)
```
---
## Citation
If you use this work, please cite both **OpenPI** and the π₀ paper:

View File

@@ -50,6 +50,13 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Relative actions: converts absolute actions to relative (relative to state).
use_relative_actions: bool = False
# Joint names to exclude from relative (kept absolute). Empty list = all dims relative.
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None

View File

@@ -21,6 +21,7 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
@@ -29,6 +30,7 @@ from lerobot.processor import (
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RelativeActionsProcessorStep,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
@@ -126,7 +128,13 @@ def make_pi0_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
relative_step = RelativeActionsProcessorStep(
enabled=config.use_relative_actions,
exclude_joints=getattr(config, "relative_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
# OpenPI order: raw → relative → normalize → model → unnormalize → absolute
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
@@ -138,6 +146,7 @@ def make_pi0_pre_post_processors(
padding="max_length",
),
DeviceProcessorStep(device=config.device),
relative_step,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
@@ -149,6 +158,7 @@ def make_pi0_pre_post_processors(
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
AbsoluteActionsProcessorStep(enabled=config.use_relative_actions, relative_step=relative_step),
DeviceProcessorStep(device="cpu"),
]

View File

@@ -17,6 +17,48 @@ It is designed as a **Vision-Language-Action model with open-world generalizatio
---
## Relative Actions
π₀.₅ supports training with **relative actions**, where the model learns relative offsets
from the current robot state instead of absolute joint positions. This mirrors the
relative-action transform in OpenPI (`DeltaActions`) and can improve performance.
### How it works
1. **During preprocessing**, absolute actions are converted to relative offsets:
`relative = action - state` (for selected joints).
2. The relative actions are normalized using statistics computed from the relative distribution.
3. **During postprocessing**, predicted relative actions are converted back to absolute:
`absolute = relative + state`.
Joints listed in `relative_exclude_joints` (e.g., gripper) are kept absolute.
### Configuration
| Parameter | Type | Default | Description |
| ------------------------- | ----------- | ------------- | ---------------------------------------------------------------- |
| `use_relative_actions` | `bool` | `False` | Enable relative-action training |
| `relative_exclude_joints` | `list[str]` | `["gripper"]` | Joint names to keep absolute (matched by substring) |
| `action_feature_names` | `list[str]` | `None` | Auto-populated from dataset metadata at runtime by `make_policy` |
### Training example
```bash
python -m lerobot.scripts.lerobot_train \
--policy.type=pi05 \
--dataset.repo_id=your_org/your_dataset \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]'
```
When `use_relative_actions=true`, the training script automatically:
- Computes relative action statistics from the dataset (sampled chunk-level relative actions)
- Replaces the standard action stats with relative stats for normalization
- Broadcasts these stats across all ranks in distributed training
---
## Citation
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:

View File

@@ -50,6 +50,13 @@ class PI05Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Relative actions: converts absolute actions to relative (relative to state).
use_relative_actions: bool = False
# Joint names to exclude from relative (kept absolute). Empty list = all dims relative.
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None

View File

@@ -24,6 +24,7 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
@@ -31,6 +32,7 @@ from lerobot.processor import (
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RelativeActionsProcessorStep,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
@@ -125,10 +127,17 @@ def make_pi05_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
relative_step = RelativeActionsProcessorStep(
enabled=config.use_relative_actions,
exclude_joints=getattr(config, "relative_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
# OpenPI order: raw → relative → normalize → model → unnormalize → absolute
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
relative_step,
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep(
@@ -150,6 +159,7 @@ def make_pi05_pre_post_processors(
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
AbsoluteActionsProcessorStep(enabled=config.use_relative_actions, relative_step=relative_step),
DeviceProcessorStep(device="cpu"),
]

View File

@@ -41,6 +41,13 @@ class PI0FastConfig(PreTrainedConfig):
max_action_dim: int = 32
max_action_tokens: int = 256
# Relative actions: converts absolute actions to relative (relative to state).
use_relative_actions: bool = False
# Joint names to exclude from relative (kept absolute). Empty list = all dims relative.
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None

View File

@@ -24,6 +24,7 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.processor import (
AbsoluteActionsProcessorStep,
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -32,6 +33,7 @@ from lerobot.processor import (
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RelativeActionsProcessorStep,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
@@ -125,12 +127,24 @@ def make_pi0_fast_pre_post_processors(
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
relative_step = RelativeActionsProcessorStep(
enabled=config.use_relative_actions,
exclude_joints=getattr(config, "relative_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
# Pi0Fast order: relative → normalize → tokenize → model → unnormalize → absolute
# This matches pi0/pi0.5: RelativeActionsProcessorStep runs first on raw absolute actions,
# caching the raw state. NormalizerProcessorStep then normalizes the raw relative actions,
# so the normalizer (and action tokenizer) sees delta values — relative stats are required.
# NOTE: RelativeActionsProcessorStep only modifies the action in the transition; it reads
# state from the observation but does not change it. NormalizerProcessorStep still runs
# before Pi0FastPrepareStateAndLanguageTokenizerProcessorStep, so the state tokenizer
# continues to receive normalized state in [-1, 1] as expected.
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
# NOTE: NormalizerProcessorStep MUST come before Pi0FastPrepareStateAndLanguageTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
relative_step,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
@@ -156,6 +170,7 @@ def make_pi0_fast_pre_post_processors(
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
AbsoluteActionsProcessorStep(enabled=config.use_relative_actions, relative_step=relative_step),
DeviceProcessorStep(device="cpu"),
]

View File

@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Real-Time Chunking (RTC) utilities for action-chunking policies."""
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
__all__ = [
"ActionInterpolator",
"ActionQueue",
"LatencyTracker",
"RTCConfig",
"RTCProcessor",
]

View File

@@ -0,0 +1,116 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action interpolation for smoother robot control.
Provides configurable Nx control rate by interpolating between consecutive actions.
Useful with RTC and action-chunking policies to reduce jerkiness.
"""
from torch import Tensor
class ActionInterpolator:
"""Interpolates between consecutive actions for smoother control.
When enabled with multiplier N, produces N actions per policy action
by linearly interpolating between the previous and current action.
Example with multiplier=3:
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
This effectively multiplies the control rate for smoother motion.
Usage:
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
# In control loop:
if interpolator.needs_new_action():
new_action = queue.get()
if new_action:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action:
robot.send_action(action)
"""
def __init__(self, multiplier: int = 1):
"""Initialize the interpolator.
Args:
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
"""
if multiplier < 1:
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
self.multiplier = multiplier
self._prev: Tensor | None = None
self._buffer: list[Tensor] = []
self._idx = 0
@property
def enabled(self) -> bool:
"""Whether interpolation is active (multiplier > 1)."""
return self.multiplier > 1
def reset(self):
"""Reset interpolation state (call between episodes)."""
self._prev = None
self._buffer = []
self._idx = 0
def needs_new_action(self) -> bool:
"""Check if a new action is needed from the queue."""
return self._idx >= len(self._buffer)
def add(self, action: Tensor) -> None:
"""Add a new action and compute interpolated sequence.
Args:
action: New action tensor from policy/queue (already on CPU).
"""
if self.multiplier > 1 and self._prev is not None:
self._buffer = []
for i in range(1, self.multiplier + 1):
t = i / self.multiplier
interp = self._prev + t * (action - self._prev)
self._buffer.append(interp)
else:
# First step: no previous action yet, so run at base FPS without interpolation.
self._buffer = [action.clone()]
self._prev = action.clone()
self._idx = 0
def get(self) -> Tensor | None:
"""Get the next interpolated action.
Returns:
Next action tensor, or None if buffer is exhausted.
"""
if self._idx >= len(self._buffer):
return None
action = self._buffer[self._idx]
self._idx += 1
return action
def get_control_interval(self, fps: float) -> float:
"""Get the control interval based on interpolation multiplier.
Args:
fps: Base frames per second.
Returns:
Control interval in seconds (divided by multiplier).
"""
return 1.0 / (fps * self.multiplier)

View File

@@ -79,6 +79,13 @@ class ActionQueue:
self.last_index += 1
return action.clone()
def clear(self) -> None:
"""Clear queued actions and reset consumption index."""
with self.lock:
self.queue = None
self.original_queue = None
self.last_index = 0
def qsize(self) -> int:
"""Get the number of remaining actions in the queue.
@@ -123,14 +130,26 @@ class ActionQueue:
with self.lock:
if self.original_queue is None:
return None
return self.original_queue[self.last_index :]
return self.original_queue[self.last_index :].clone()
def get_processed_left_over(self) -> Tensor | None:
"""Get leftover processed actions (the actions currently executed by the robot).
Returns:
Tensor | None: Remaining processed actions (remaining_steps, action_dim),
or None if no processed queue exists.
"""
with self.lock:
if self.queue is None:
return None
return self.queue[self.last_index :].clone()
def merge(
self,
original_actions: Tensor,
processed_actions: Tensor,
real_delay: int,
action_index_before_inference: int | None = 0,
action_index_before_inference: int | None = None,
):
"""Merge new actions into the queue.
@@ -145,10 +164,10 @@ class ActionQueue:
action_index_before_inference: Index before inference started, for validation.
"""
with self.lock:
self._check_delays(real_delay, action_index_before_inference)
delay = self._check_and_resolve_delays(real_delay, action_index_before_inference)
if self.cfg.enabled:
self._replace_actions_queue(original_actions, processed_actions, real_delay)
self._replace_actions_queue(original_actions, processed_actions, delay)
return
self._append_actions_queue(original_actions, processed_actions)
@@ -164,12 +183,13 @@ class ActionQueue:
processed_actions: Post-processed actions for robot.
real_delay: Number of time steps to skip due to inference delay.
"""
self.original_queue = original_actions[real_delay:].clone()
self.queue = processed_actions[real_delay:].clone()
clamped_delay = max(0, min(real_delay, len(original_actions), len(processed_actions)))
self.original_queue = original_actions[clamped_delay:].clone()
self.queue = processed_actions[clamped_delay:].clone()
logger.debug(f"original_actions shape: {self.original_queue.shape}")
logger.debug(f"processed_actions shape: {self.queue.shape}")
logger.debug(f"real_delay: {real_delay}")
logger.debug(f"real_delay: {real_delay}, clamped_delay: {clamped_delay}")
self.last_index = 0
@@ -196,7 +216,9 @@ class ActionQueue:
self.last_index = 0
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
def _check_and_resolve_delays(
self, real_delay: int, action_index_before_inference: int | None = None
) -> int:
"""Validate that computed delays match expectations.
Compares the delay computed from inference latency with the actual
@@ -205,15 +227,20 @@ class ActionQueue:
Args:
real_delay: Delay computed from inference latency.
action_index_before_inference: Action index when inference started.
"""
if action_index_before_inference is None:
return
indexes_diff = self.last_index - action_index_before_inference
if indexes_diff != real_delay:
# Let's check that action index difference (real delay calculated based on action queue)
# is the same as delay calculated based on inference latency
logger.warning(
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. "
f"Indexes diff: {indexes_diff}, real delay: {real_delay}"
)
Returns:
int: Delay to use.
"""
effective_delay = max(0, real_delay)
if action_index_before_inference is not None:
indexes_diff = max(0, self.last_index - action_index_before_inference)
if indexes_diff != real_delay:
logger.warning(
"Indexes diff is not equal to real delay. indexes_diff=%d, real_delay=%d",
indexes_diff,
real_delay,
)
return real_delay
return effective_delay

View File

@@ -55,7 +55,7 @@ class SmolVLAConfig(PreTrainedConfig):
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Converts joint dimensions to relative values with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False

View File

@@ -467,8 +467,8 @@ class VQBeTHead(nn.Module):
self.vqvae_model.optimized_steps += 1
# if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
if self.vqvae_model.optimized_steps >= n_vqvae_training_steps:
self.vqvae_model.discretized = torch.tensor(True)
self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
self.vqvae_model.discretized.fill_(True)
self.vqvae_model.vq_layer.freeze_codebook.fill_(True)
print("Finished discretizing action data!")
self.vqvae_model.eval()
for param in self.vqvae_model.vq_layer.parameters():

View File

@@ -75,6 +75,12 @@ from .policy_robot_bridge import (
PolicyActionToRobotActionProcessorStep,
RobotActionToPolicyActionProcessorStep,
)
from .relative_action_processor import (
AbsoluteActionsProcessorStep,
RelativeActionsProcessorStep,
to_absolute_actions,
to_relative_actions,
)
from .rename_processor import RenameObservationsProcessorStep
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
@@ -100,6 +106,8 @@ __all__ = [
"make_default_teleop_action_processor",
"make_default_robot_action_processor",
"make_default_robot_observation_processor",
"AbsoluteActionsProcessorStep",
"RelativeActionsProcessorStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"NormalizerProcessorStep",
@@ -129,6 +137,8 @@ __all__ = [
"transition_to_batch",
"TransitionKey",
"TruncatedProcessorStep",
"to_absolute_actions",
"to_relative_actions",
"UnnormalizerProcessorStep",
"VanillaObservationProcessorStep",
]

View File

@@ -0,0 +1,208 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
import torch
from torch import Tensor
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import OBS_STATE
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
from .pipeline import ProcessorStep, ProcessorStepRegistry
# Re-export for backward compatibility
__all__ = [
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"RelativeActionsProcessorStep",
"AbsoluteActionsProcessorStep",
"to_relative_actions",
"to_absolute_actions",
]
def to_relative_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
"""Convert absolute actions to relative: relative = action - state (for masked dims).
Args:
actions: (B, T, action_dim) or (B, action_dim).
state: (B, state_dim). Broadcast across time dimension.
mask: Which dims to convert. Can be shorter than action_dim.
"""
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
dims = mask_t.shape[0]
# Align state to the same device/dtype as actions. _last_state is cached before
# DeviceProcessorStep moves the transition, so it can be on CPU while actions are on CUDA.
if state.device != actions.device or state.dtype != actions.dtype:
state = state.to(device=actions.device, dtype=actions.dtype)
state_offset = state[..., :dims] * mask_t
if actions.ndim == 3:
state_offset = state_offset.unsqueeze(-2)
actions = actions.clone()
actions[..., :dims] -= state_offset
return actions
def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
"""Convert relative actions back to absolute: absolute = relative + state (for masked dims).
Args:
actions: (B, T, action_dim) or (B, action_dim).
state: (B, state_dim). Broadcast across time dimension.
mask: Which dims to convert. Can be shorter than action_dim.
"""
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
dims = mask_t.shape[0]
# Align state to the same device/dtype as actions. _last_state is cached before
# DeviceProcessorStep moves the transition, so it can be on CPU while actions are on CUDA.
if state.device != actions.device or state.dtype != actions.dtype:
state = state.to(device=actions.device, dtype=actions.dtype)
state_offset = state[..., :dims] * mask_t
if actions.ndim == 3:
state_offset = state_offset.unsqueeze(-2)
actions = actions.clone()
actions[..., :dims] += state_offset
return actions
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class RelativeActionsProcessorStep(ProcessorStep):
"""Converts absolute actions to relative actions (action -= state) for masked dimensions.
Mirrors OpenPI's DeltaActions transform. Applied during preprocessing so the model
trains on relative offsets instead of absolute positions.
Caches the last seen state so a paired AbsoluteActionsProcessorStep can reverse
the conversion during postprocessing.
Attributes:
enabled: Whether to apply the relative conversion.
exclude_joints: Joint names to keep absolute (not converted to relative).
action_names: Action dimension names from dataset metadata, used to build
the mask from exclude_joints. If None, all dims are converted.
"""
enabled: bool = False
exclude_joints: list[str] = field(default_factory=list)
action_names: list[str] | None = None
_last_state: torch.Tensor | None = field(default=None, init=False, repr=False)
def _build_mask(self, action_dim: int) -> list[bool]:
if not self.exclude_joints or self.action_names is None:
return [True] * action_dim
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
if not exclude_tokens:
return [True] * action_dim
mask = []
for name in self.action_names[:action_dim]:
action_name = str(name).lower()
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
mask.append(not is_excluded)
if len(mask) < action_dim:
mask.extend([True] * (action_dim - len(mask)))
return mask
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION, {})
state = observation.get(OBS_STATE) if observation else None
# Always cache state for the paired AbsoluteActionsProcessorStep
if state is not None:
self._last_state = state
if not self.enabled:
return transition
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is None or state is None:
return new_transition
mask = self._build_mask(action.shape[-1])
new_transition[TransitionKey.ACTION] = to_relative_actions(action, state, mask)
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"enabled": self.enabled,
"exclude_joints": self.exclude_joints,
"action_names": self.action_names,
}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("absolute_actions_processor")
@dataclass
class AbsoluteActionsProcessorStep(ProcessorStep):
"""Converts relative actions back to absolute actions (action += state) for all dimensions.
Mirrors OpenPI's AbsoluteActions transform. Applied during postprocessing so
predicted relative offsets are converted back to absolute positions for execution.
Reads the cached state from its paired RelativeActionsProcessorStep.
Attributes:
enabled: Whether to apply the absolute conversion.
relative_step: Reference to the paired RelativeActionsProcessorStep that caches state.
"""
enabled: bool = False
relative_step: RelativeActionsProcessorStep | None = field(default=None, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition:
if not self.enabled:
return transition
if self.relative_step is None:
raise RuntimeError(
"AbsoluteActionsProcessorStep requires a paired RelativeActionsProcessorStep "
"but relative_step is None. Ensure relative_step is set when constructing the postprocessor."
)
if self.relative_step._last_state is None:
raise RuntimeError(
"AbsoluteActionsProcessorStep requires state from RelativeActionsProcessorStep "
"but no state has been cached. Ensure the preprocessor runs before the postprocessor."
)
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is None:
return new_transition
mask = self.relative_step._build_mask(action.shape[-1])
new_transition[TransitionKey.ACTION] = to_absolute_actions(
action, self.relative_step._last_state, mask
)
return new_transition
def get_config(self) -> dict[str, Any]:
return {"enabled": self.enabled}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features

View File

@@ -563,7 +563,7 @@ class ReplayBuffer:
)
# Start writing images if needed
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
lerobot_dataset.writer.start_image_writer(num_processes=0, num_threads=3)
# Convert transitions into episodes and frames
@@ -603,10 +603,10 @@ class ReplayBuffer:
lerobot_dataset.save_episode()
# Save any remaining frames in the buffer
if lerobot_dataset.episode_buffer["size"] > 0:
if lerobot_dataset.has_pending_frames():
lerobot_dataset.save_episode()
lerobot_dataset.stop_image_writer()
lerobot_dataset.writer.stop_image_writer()
lerobot_dataset.finalize()
return lerobot_dataset

View File

@@ -752,8 +752,7 @@ def replay_trajectory(
episodes=[cfg.dataset.replay_episode],
download_videos=False,
)
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode)
actions = episode_frames.select_columns(ACTION)
actions = dataset.select_columns(ACTION)
_, info = env.reset()

View File

@@ -39,13 +39,23 @@ class BiOpenArmFollower(Robot):
super().__init__(config)
self.config = config
# Top-level cameras are distributed evenly: each arm's OpenArmFollower
# will only open the cameras assigned to it. Per-arm cameras are used
# as fallback when top-level cameras are empty.
if config.cameras:
left_cameras = config.cameras
right_cameras = {}
else:
left_cameras = config.left_arm_config.cameras
right_cameras = config.right_arm_config.cameras
left_arm_config = OpenArmFollowerConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=config.left_arm_config.cameras,
cameras=left_cameras,
side=config.left_arm_config.side,
can_interface=config.left_arm_config.can_interface,
use_can_fd=config.left_arm_config.use_can_fd,
@@ -63,7 +73,7 @@ class BiOpenArmFollower(Robot):
port=config.right_arm_config.port,
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
max_relative_target=config.right_arm_config.max_relative_target,
cameras=config.right_arm_config.cameras,
cameras=right_cameras,
side=config.right_arm_config.side,
can_interface=config.right_arm_config.can_interface,
use_can_fd=config.right_arm_config.use_can_fd,
@@ -86,20 +96,19 @@ class BiOpenArmFollower(Robot):
left_arm_motors_ft = self.left_arm._motors_ft
right_arm_motors_ft = self.right_arm._motors_ft
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
return {
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
}
@property
def _cameras_ft(self) -> dict[str, tuple]:
left_arm_cameras_ft = self.left_arm._cameras_ft
right_arm_cameras_ft = self.right_arm._cameras_ft
return {
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
}
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
# "right_wrist"), so we merge them directly — unlike motors which need the
# left_/right_ prefix to disambiguate identical per-arm joint names.
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -139,13 +148,19 @@ class BiOpenArmFollower(Robot):
def get_observation(self) -> RobotObservation:
obs_dict = {}
# Add "left_" prefix
left_obs = self.left_arm.get_observation()
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
# Camera keys that should NOT get the arm prefix (they already have unique names)
left_cam_keys = set(self.left_arm.cameras.keys())
right_cam_keys = set(self.right_arm.cameras.keys())
# Add "right_" prefix
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
right_obs = self.right_arm.get_observation()
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
for key, value in right_obs.items():
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
left_obs = self.left_arm.get_observation()
for key, value in left_obs.items():
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
return obs_dict
@@ -172,7 +187,7 @@ class BiOpenArmFollower(Robot):
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
return {**prefixed_sent_action_right, **prefixed_sent_action_left}
@check_if_not_connected
def disconnect(self):

View File

@@ -14,17 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from lerobot.robots.openarm_follower import OpenArmFollowerConfigBase
from ..config import RobotConfig
@RobotConfig.register_subclass("bi_openarm_follower")
@dataclass
@dataclass(kw_only=True)
class BiOpenArmFollowerConfig(RobotConfig):
"""Configuration class for Bi OpenArm Follower robots."""
id: str | None = "bi_openarm_follower"
left_arm_config: OpenArmFollowerConfigBase
right_arm_config: OpenArmFollowerConfigBase
# Top-level cameras shared across both arms.
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -33,21 +33,40 @@ from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
logger = logging.getLogger(__name__)
# Action feature keys
ACTION_LINEAR_VEL = "linear.vel"
ACTION_ANGULAR_VEL = "angular.vel"
ACTION_LINEAR_VEL = "linear_velocity"
ACTION_ANGULAR_VEL = "angular_velocity"
# Observation feature keys
# Observation feature keys — cameras
OBS_FRONT = "front"
OBS_REAR = "rear"
OBS_LINEAR_VEL = "linear.vel"
OBS_BATTERY_LEVEL = "battery.level"
OBS_ORIENTATION_DEG = "orientation.deg"
OBS_GPS_LATITUDE = "gps.latitude"
OBS_GPS_LONGITUDE = "gps.longitude"
OBS_GPS_SIGNAL = "gps.signal"
OBS_SIGNAL_LEVEL = "signal.level"
# Observation feature keys — telemetry
OBS_SPEED = "speed"
OBS_BATTERY_LEVEL = "battery_level"
OBS_ORIENTATION = "orientation"
OBS_GPS_LATITUDE = "gps_latitude"
OBS_GPS_LONGITUDE = "gps_longitude"
OBS_GPS_SIGNAL = "gps_signal"
OBS_SIGNAL_LEVEL = "signal_level"
OBS_VIBRATION = "vibration"
OBS_LAMP_STATE = "lamp.state"
OBS_LAMP = "lamp"
# Observation feature keys — IMU sensors
OBS_ACCELEROMETER_X = "accelerometer_x"
OBS_ACCELEROMETER_Y = "accelerometer_y"
OBS_ACCELEROMETER_Z = "accelerometer_z"
OBS_GYROSCOPE_X = "gyroscope_x"
OBS_GYROSCOPE_Y = "gyroscope_y"
OBS_GYROSCOPE_Z = "gyroscope_z"
OBS_MAGNETOMETER_X = "magnetometer_filtered_x"
OBS_MAGNETOMETER_Y = "magnetometer_filtered_y"
OBS_MAGNETOMETER_Z = "magnetometer_filtered_z"
# Observation feature keys — wheel RPMs
OBS_WHEEL_RPM_0 = "wheel_rpm_0"
OBS_WHEEL_RPM_1 = "wheel_rpm_1"
OBS_WHEEL_RPM_2 = "wheel_rpm_2"
OBS_WHEEL_RPM_3 = "wheel_rpm_3"
class EarthRoverMiniPlus(Robot):
@@ -154,33 +173,60 @@ class EarthRoverMiniPlus(Robot):
dict: Observation features with types/shapes:
- front: (480, 640, 3) - Front camera RGB image
- rear: (480, 640, 3) - Rear camera RGB image
- linear.vel: float - Current speed (0-1, SDK reports only positive speeds)
- battery.level: float - Battery level (0-1, normalized from 0-100)
- orientation.deg: float - Robot orientation (0-1, normalized from raw value)
- gps.latitude: float - GPS latitude coordinate
- gps.longitude: float - GPS longitude coordinate
- gps.signal: float - GPS signal strength (0-1, normalized from percentage)
- signal.level: float - Network signal level (0-1, normalized from 0-5)
- speed: float - Current speed (raw SDK value)
- battery_level: float - Battery level (0-100)
- orientation: float - Robot orientation in degrees
- gps_latitude: float - GPS latitude coordinate
- gps_longitude: float - GPS longitude coordinate
- gps_signal: float - GPS signal strength (percentage)
- signal_level: float - Network signal level (0-5)
- vibration: float - Vibration sensor reading
- lamp.state: float - Lamp state (0=off, 1=on)
- lamp: float - Lamp state (0=off, 1=on)
- accelerometer_x: float - Accelerometer X axis (raw SDK value)
- accelerometer_y: float - Accelerometer Y axis (raw SDK value)
- accelerometer_z: float - Accelerometer Z axis (raw SDK value)
- gyroscope_x: float - Gyroscope X axis (raw SDK value)
- gyroscope_y: float - Gyroscope Y axis (raw SDK value)
- gyroscope_z: float - Gyroscope Z axis (raw SDK value)
- magnetometer_filtered_x: float - Magnetometer X axis (raw SDK value)
- magnetometer_filtered_y: float - Magnetometer Y axis (raw SDK value)
- magnetometer_filtered_z: float - Magnetometer Z axis (raw SDK value)
- wheel_rpm_0: float - Wheel 0 RPM
- wheel_rpm_1: float - Wheel 1 RPM
- wheel_rpm_2: float - Wheel 2 RPM
- wheel_rpm_3: float - Wheel 3 RPM
"""
return {
# Cameras (height, width, channels)
OBS_FRONT: (480, 640, 3),
OBS_REAR: (480, 640, 3),
# Motion state
OBS_LINEAR_VEL: float,
# Robot state
# Telemetry
OBS_SPEED: float,
OBS_BATTERY_LEVEL: float,
OBS_ORIENTATION_DEG: float,
# GPS
OBS_ORIENTATION: float,
OBS_GPS_LATITUDE: float,
OBS_GPS_LONGITUDE: float,
OBS_GPS_SIGNAL: float,
# Sensors
OBS_SIGNAL_LEVEL: float,
OBS_VIBRATION: float,
OBS_LAMP_STATE: float,
OBS_LAMP: float,
# IMU — accelerometer
OBS_ACCELEROMETER_X: float,
OBS_ACCELEROMETER_Y: float,
OBS_ACCELEROMETER_Z: float,
# IMU — gyroscope
OBS_GYROSCOPE_X: float,
OBS_GYROSCOPE_Y: float,
OBS_GYROSCOPE_Z: float,
# IMU — magnetometer
OBS_MAGNETOMETER_X: float,
OBS_MAGNETOMETER_Y: float,
OBS_MAGNETOMETER_Z: float,
# Wheel RPMs
OBS_WHEEL_RPM_0: float,
OBS_WHEEL_RPM_1: float,
OBS_WHEEL_RPM_2: float,
OBS_WHEEL_RPM_3: float,
}
@cached_property
@@ -189,8 +235,8 @@ class EarthRoverMiniPlus(Robot):
Returns:
dict: Action features with types:
- linear.vel: float - Target linear velocity
- angular.vel: float - Target angular velocity
- linear_velocity: float - Target linear velocity (-1 to 1)
- angular_velocity: float - Target angular velocity (-1 to 1)
"""
return {
ACTION_LINEAR_VEL: float,
@@ -201,19 +247,29 @@ class EarthRoverMiniPlus(Robot):
def get_observation(self) -> RobotObservation:
"""Get current robot observation from SDK.
Camera frames are retrieved from SDK endpoints /v2/front and /v2/rear.
Frames are decoded from base64 and converted from BGR to RGB format.
Robot telemetry is retrieved from /data endpoint.
Sensor arrays (accels, gyros, mags, rpms) each contain entries of
[values..., timestamp]; the latest reading from each array is used.
Returns:
RobotObservation: Observation containing:
- front: Front camera image (480, 640, 3) in RGB format
- rear: Rear camera image (480, 640, 3) in RGB format
- linear.vel: Current speed (0-1, SDK reports only positive speeds)
- battery.level: Battery level (0-1, normalized from 0-100)
- orientation.deg: Robot orientation (0-1, normalized from raw value)
- gps.latitude: GPS latitude coordinate
- gps.longitude: GPS longitude coordinate
- gps.signal: GPS signal strength (0-1, normalized from percentage)
- signal.level: Network signal level (0-1, normalized from 0-5)
- vibration: Vibration sensor reading
- lamp.state: Lamp state (0=off, 1=on)
- speed: float - Current speed (raw SDK value)
- battery_level: float - Battery level (0-100)
- orientation: float - Robot orientation in degrees
- gps_latitude: float - GPS latitude coordinate
- gps_longitude: float - GPS longitude coordinate
- gps_signal: float - GPS signal strength (percentage)
- signal_level: float - Network signal level (0-5)
- vibration: float - Vibration sensor reading
- lamp: float - Lamp state (0=off, 1=on)
- accelerometer_x/y/z: float - Accelerometer axes (raw SDK value)
- gyroscope_x/y/z: float - Gyroscope axes (raw SDK value)
- magnetometer_filtered_x/y/z: float - Magnetometer axes (raw SDK value)
- wheel_rpm_0/1/2/3: float - Wheel RPMs
Raises:
DeviceNotConnectedError: If robot is not connected
@@ -235,22 +291,41 @@ class EarthRoverMiniPlus(Robot):
# Get robot state from SDK
robot_data = self._get_robot_data()
# Motion state
observation[OBS_LINEAR_VEL] = robot_data["speed"] / 100.0 # Normalize 0-100 to 0-1
# Telemetry
observation[OBS_SPEED] = float(robot_data["speed"])
observation[OBS_BATTERY_LEVEL] = float(robot_data["battery"])
observation[OBS_ORIENTATION] = float(robot_data["orientation"])
observation[OBS_GPS_LATITUDE] = float(robot_data["latitude"])
observation[OBS_GPS_LONGITUDE] = float(robot_data["longitude"])
observation[OBS_GPS_SIGNAL] = float(robot_data["gps_signal"])
observation[OBS_SIGNAL_LEVEL] = float(robot_data["signal_level"])
observation[OBS_VIBRATION] = float(robot_data["vibration"])
observation[OBS_LAMP] = float(robot_data["lamp"])
# Robot state
observation[OBS_BATTERY_LEVEL] = robot_data["battery"] / 100.0 # Normalize 0-100 to 0-1
observation[OBS_ORIENTATION_DEG] = robot_data["orientation"] / 360.0 # Normalize to 0-1
# Accelerometer — latest reading from accels array [x, y, z, ts]
accel = self._latest_sensor_reading(robot_data, "accels", n_values=3)
observation[OBS_ACCELEROMETER_X] = accel[0]
observation[OBS_ACCELEROMETER_Y] = accel[1]
observation[OBS_ACCELEROMETER_Z] = accel[2]
# GPS data
observation[OBS_GPS_LATITUDE] = robot_data["latitude"]
observation[OBS_GPS_LONGITUDE] = robot_data["longitude"]
observation[OBS_GPS_SIGNAL] = robot_data["gps_signal"] / 100.0 # Normalize percentage to 0-1
# Gyroscope — latest reading from gyros array [x, y, z, ts]
gyro = self._latest_sensor_reading(robot_data, "gyros", n_values=3)
observation[OBS_GYROSCOPE_X] = gyro[0]
observation[OBS_GYROSCOPE_Y] = gyro[1]
observation[OBS_GYROSCOPE_Z] = gyro[2]
# Sensors
observation[OBS_SIGNAL_LEVEL] = robot_data["signal_level"] / 5.0 # Normalize 0-5 to 0-1
observation[OBS_VIBRATION] = robot_data["vibration"]
observation[OBS_LAMP_STATE] = float(robot_data["lamp"]) # 0 or 1
# Magnetometer — latest reading from mags array [x, y, z, ts]
mag = self._latest_sensor_reading(robot_data, "mags", n_values=3)
observation[OBS_MAGNETOMETER_X] = mag[0]
observation[OBS_MAGNETOMETER_Y] = mag[1]
observation[OBS_MAGNETOMETER_Z] = mag[2]
# Wheel RPMs — latest reading from rpms array [w0, w1, w2, w3, ts]
rpm = self._latest_sensor_reading(robot_data, "rpms", n_values=4)
observation[OBS_WHEEL_RPM_0] = rpm[0]
observation[OBS_WHEEL_RPM_1] = rpm[1]
observation[OBS_WHEEL_RPM_2] = rpm[2]
observation[OBS_WHEEL_RPM_3] = rpm[3]
return observation
@@ -260,11 +335,12 @@ class EarthRoverMiniPlus(Robot):
Args:
action: Action dict with keys:
- linear.vel: Target linear velocity (-1 to 1)
- angular.vel: Target angular velocity (-1 to 1)
- linear_velocity: Target linear velocity (-1 to 1)
- angular_velocity: Target angular velocity (-1 to 1)
Returns:
RobotAction: The action that was sent (matches action_features keys)
Raises:
DeviceNotConnectedError: If robot is not connected
@@ -272,18 +348,14 @@ class EarthRoverMiniPlus(Robot):
Actions are sent to SDK via POST /control endpoint.
SDK expects commands in range [-1, 1].
"""
# Extract action values and convert to float
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
angular = float(action.get(ACTION_ANGULAR_VEL, 0.0))
# Send command to SDK
try:
self._send_command_to_sdk(linear, angular)
except Exception as e:
logger.error(f"Error sending action: {e}")
# Return action in format matching action_features
return {
ACTION_LINEAR_VEL: linear,
ACTION_ANGULAR_VEL: angular,
@@ -394,11 +466,27 @@ class EarthRoverMiniPlus(Robot):
logger.error(f"Error decoding image: {e}")
return None
@staticmethod
def _latest_sensor_reading(robot_data: dict, key: str, n_values: int) -> list[float]:
"""Extract the latest sensor reading from an SDK sensor array.
The SDK returns sensor arrays like ``accels``, ``gyros``, ``mags``,
``rpms`` where each entry is ``[value_0, ..., value_n, timestamp]``.
This helper returns the *n_values* leading floats from the last entry,
falling back to zeros when the key is missing or the array is empty.
"""
readings = robot_data.get(key)
if readings and len(readings) > 0:
latest = readings[-1]
return [float(v) for v in latest[:n_values]]
return [0.0] * n_values
def _get_robot_data(self) -> dict:
"""Get robot telemetry data from SDK.
Returns:
dict: Robot telemetry data including battery, speed, orientation, GPS, etc:
dict: Robot telemetry data including battery, speed, orientation, GPS,
and sensor arrays (accels, gyros, mags, rpms):
- Current data (if request succeeds)
- Cached data (if request fails but cache exists)
- Default values (if request fails and no cache exists yet)
@@ -420,19 +508,23 @@ class EarthRoverMiniPlus(Robot):
# Fallback: use cache or default values
if self._last_robot_data is not None:
return self._last_robot_data
else:
# Return dict with default values (used only on first failure before any cache exists)
return {
"speed": 0,
"battery": 0,
"orientation": 0,
"latitude": 0.0,
"longitude": 0.0,
"gps_signal": 0,
"signal_level": 0,
"vibration": 0.0,
"lamp": 0,
}
# Return dict with default values (used only on first failure before any cache exists)
return {
"speed": 0,
"battery": 0,
"orientation": 0,
"latitude": 0.0,
"longitude": 0.0,
"gps_signal": 0,
"signal_level": 0,
"vibration": 0.0,
"lamp": 0,
"accels": [],
"gyros": [],
"mags": [],
"rpms": [],
}
def _send_command_to_sdk(self, linear: float, angular: float, lamp: int = 0) -> bool:
"""Send control command to SDK.

View File

@@ -18,7 +18,7 @@
Edit LeRobot datasets using various transformation tools.
This script allows you to delete episodes, split datasets, merge datasets,
remove features, modify tasks, and convert image datasets to video format.
remove features, modify tasks, recompute stats, and convert image datasets to video format.
When new_repo_id is specified, creates a new dataset.
Path semantics (v2): --root and --new_root are exact dataset folders containing
@@ -148,6 +148,21 @@ Show dataset information without feature details:
--operation.type info \
--operation.show_features false
Recompute dataset statistics:
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type recompute_stats
Recompute stats for relative actions and push to hub:
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type recompute_stats \
--operation.relative_action true \
--operation.chunk_size 50 \
--operation.relative_exclude_joints "['gripper']" \
--operation.num_workers 4 \
--push_to_hub true
Using JSON config file:
lerobot-edit-dataset \
--config_path path/to/edit_config.json
@@ -168,6 +183,7 @@ from lerobot.datasets.dataset_tools import (
delete_episodes,
merge_datasets,
modify_tasks,
recompute_stats,
remove_feature,
split_dataset,
)
@@ -230,6 +246,16 @@ class ConvertImageToVideoConfig(OperationConfig):
max_frames_per_batch: int | None = None
@OperationConfig.register_subclass("recompute_stats")
@dataclass
class RecomputeStatsConfig(OperationConfig):
skip_image_video: bool = True
relative_action: bool = False
relative_exclude_joints: list[str] | None = None
chunk_size: int = 50
num_workers: int = 0
@OperationConfig.register_subclass("info")
@dataclass
class InfoConfig(OperationConfig):
@@ -525,6 +551,35 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
logging.info("Dataset saved locally (not pushed to hub)")
def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, RecomputeStatsConfig):
raise ValueError("Operation config must be RecomputeStatsConfig")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
logging.info(f"Recomputing stats for {cfg.repo_id}")
if cfg.operation.relative_action:
logging.info(
f"Relative action stats enabled (chunk_size={cfg.operation.chunk_size}, "
f"exclude_joints={cfg.operation.relative_exclude_joints})"
)
recompute_stats(
dataset,
skip_image_video=cfg.operation.skip_image_video,
relative_action=cfg.operation.relative_action,
relative_exclude_joints=cfg.operation.relative_exclude_joints,
chunk_size=cfg.operation.chunk_size,
num_workers=cfg.operation.num_workers,
)
logging.info(f"Stats written to {dataset.root}")
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {dataset.meta.repo_id}...")
dataset.push_to_hub()
def _get_dataset_size(repo_path):
import os
@@ -596,6 +651,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_modify_tasks(cfg)
elif operation_type == "convert_image_to_video":
handle_convert_image_to_video(cfg)
elif operation_type == "recompute_stats":
handle_recompute_stats(cfg)
elif operation_type == "info":
handle_info(cfg)
else:

View File

@@ -47,6 +47,7 @@ You can learn about the CLI options for this script in the `EvalPipelineConfig`
"""
import concurrent.futures as cf
import copy
import json
import logging
import threading
@@ -56,7 +57,6 @@ from collections.abc import Callable
from contextlib import nullcontext
from copy import deepcopy
from dataclasses import asdict
from functools import partial
from pathlib import Path
from pprint import pformat
from typing import Any, TypedDict
@@ -73,7 +73,6 @@ from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.utils import (
add_envs_task,
check_env_attributes_and_types,
close_envs,
preprocess_observation,
@@ -166,9 +165,9 @@ def rollout(
if return_observations:
all_observations.append(deepcopy(observation))
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
# Infer "task" from sub-environments.
# env.call() works with both SyncVectorEnv and AsyncVectorEnv.
observation["task"] = env.call("task")
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
observation = env_preprocessor(observation)
@@ -734,34 +733,48 @@ def eval_policy_all(
group_acc[group]["video_paths"].extend(paths)
overall["video_paths"].extend(paths)
def _make_thread_policy(p: PreTrainedPolicy) -> PreTrainedPolicy:
"""Shallow copy sharing weight tensors, with independent per-thread state.
copy.copy() gives a new Python object whose _parameters dict is a shared
reference (same tensor storage, zero extra VRAM). reset() then rebinds
mutable state (action queues etc.) to fresh per-thread objects.
Note: does NOT work for ACT with temporal_ensemble_coeff — that policy's
reset() mutates a shared sub-object. Use max_parallel_tasks=1 for that config.
"""
thread_p = copy.copy(p)
thread_p.reset()
return thread_p
# Choose runner (sequential vs threaded)
task_runner = partial(
run_one,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
max_episodes_rendered=max_episodes_rendered,
videos_dir=videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
)
_runner_kwargs = {
"env_preprocessor": env_preprocessor,
"env_postprocessor": env_postprocessor,
"preprocessor": preprocessor,
"postprocessor": postprocessor,
"n_episodes": n_episodes,
"max_episodes_rendered": max_episodes_rendered,
"videos_dir": videos_dir,
"return_episode_data": return_episode_data,
"start_seed": start_seed,
}
if max_parallel_tasks <= 1:
# sequential path (single accumulator path on the main thread)
# NOTE: keeping a single-threaded accumulator avoids concurrent list appends or locks
for task_group, task_id, env in tasks:
tg, tid, metrics = task_runner(task_group, task_id, env)
tg, tid, metrics = run_one(task_group, task_id, env, policy=policy, **_runner_kwargs)
_accumulate_to(tg, metrics)
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
else:
# threaded path: submit all tasks, consume completions on main thread and accumulate there
# threaded path: each thread gets a shallow policy copy (shared weights, independent state)
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
fut2meta = {}
for task_group, task_id, env in tasks:
fut = executor.submit(task_runner, task_group, task_id, env)
fut = executor.submit(
run_one, task_group, task_id, env, policy=_make_thread_policy(policy), **_runner_kwargs
)
fut2meta[fut] = (task_group, task_id)
for fut in cf.as_completed(fut2meta):
tg, tid, metrics = fut.result()

View File

@@ -65,6 +65,7 @@ def get_sys_info() -> dict[str, str]:
"Platform": platform.platform(),
"Python version": platform.python_version(),
"Huggingface Hub version": get_package_version("huggingface_hub"),
"Transformers version": get_package_version("transformers"),
"Datasets version": get_package_version("datasets"),
"Numpy version": get_package_version("numpy"),
"FFmpeg version": get_ffmpeg_version(),

View File

@@ -74,6 +74,8 @@ from pathlib import Path
from pprint import pformat
from typing import Any
import torch
from lerobot.cameras import ( # noqa: F401
CameraConfig, # noqa: F401
)
@@ -90,6 +92,7 @@ from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_featur
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
PolicyAction,
@@ -226,6 +229,9 @@ class RecordConfig:
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x)
# Only applies when using a policy (not teleop)
interpolation_multiplier: int = 1
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
@@ -298,6 +304,7 @@ def record_loop(
control_time_s: int | None = None,
single_task: str | None = None,
display_data: bool = False,
interpolator: ActionInterpolator | None = None,
display_compressed_images: bool = False,
):
if dataset is not None and dataset.fps != fps:
@@ -334,6 +341,16 @@ def record_loop(
preprocessor.reset()
postprocessor.reset()
# Reset interpolator if provided
if interpolator is not None:
interpolator.reset()
# Calculate control interval based on interpolation
use_interpolation = interpolator is not None and interpolator.enabled and policy is not None
control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps
# Pre-compute action key order outside the hot loop — it won't change mid-episode.
action_keys = sorted(robot.action_features) if use_interpolation else []
no_action_count = 0
timestamp = 0
start_episode_t = time.perf_counter()
@@ -353,28 +370,67 @@ def record_loop(
if policy is not None or dataset is not None:
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
# Track whether this iteration should be recorded to the dataset.
# Interpolated-only iterations send actions to the robot but don't record frames,
# keeping the dataset at the original fps while the robot moves at the higher rate.
is_record_frame = True
# Get action from either policy or teleop
if policy is not None and preprocessor is not None and postprocessor is not None:
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
# With interpolation: only call policy when interpolator needs new action
if use_interpolation:
ran_inference = False
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
if interpolator.needs_new_action():
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
act_processed_policy = make_robot_action(action_values, dataset.features)
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
action_tensor = torch.tensor([robot_action_to_send[k] for k in action_keys])
interpolator.add(action_tensor)
ran_inference = True
interp_action = interpolator.get()
if interp_action is not None:
robot_action_to_send = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
action_values = robot_action_to_send
else:
continue
is_record_frame = ran_inference
else:
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
# Applies a pipeline to the action, default is IdentityProcessor
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
elif policy is None and isinstance(teleop, Teleoperator):
act = teleop.get_action()
if robot.name == "unitree_g1":
teleop.send_feedback(obs)
act = teleop.get_action()
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
act_processed_teleop = teleop_action_processor((act, obs))
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
elif policy is None and isinstance(teleop, list):
arm_action = teleop_arm.get_action()
@@ -383,6 +439,8 @@ def record_loop(
base_action = robot._from_keyboard_to_base_action(keyboard_action)
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
act_processed_teleop = teleop_action_processor((act, obs))
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
else:
no_action_count += 1
if no_action_count == 1 or no_action_count % 10 == 0:
@@ -393,22 +451,14 @@ def record_loop(
)
continue
# Applies a pipeline to the action, default is IdentityProcessor
if policy is not None and act_processed_policy is not None:
action_values = act_processed_policy
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
else:
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
# Send action to robot
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
_sent_action = robot.send_action(robot_action_to_send)
# Write to dataset
if dataset is not None:
# Write to dataset (only on real policy frames, not interpolated-only iterations)
if dataset is not None and is_record_frame:
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": single_task}
dataset.add_frame(frame)
@@ -420,7 +470,7 @@ def record_loop(
dt_s = time.perf_counter() - start_loop_t
sleep_time_s: float = 1 / fps - dt_s
sleep_time_s: float = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
@@ -468,7 +518,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
try:
if cfg.resume:
dataset = LeRobotDataset(
num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
dataset = LeRobotDataset.resume(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
@@ -476,13 +527,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
if num_cameras > 0
else 0,
)
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
else:
# Create empty dataset or load existing saved episodes
@@ -507,6 +556,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
preprocessor = None
postprocessor = None
interpolator = None
if cfg.policy is not None:
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
@@ -517,6 +567,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
# Create interpolator for smoother policy control
if cfg.interpolation_multiplier > 1:
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate")
robot.connect()
if teleop is not None:
@@ -548,6 +602,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
interpolator=interpolator,
display_compressed_images=display_compressed_images,
)

View File

@@ -104,15 +104,13 @@ def replay(cfg: ReplayConfig):
robot = make_robot_from_config(cfg.robot)
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode)
actions = episode_frames.select_columns(ACTION)
actions = dataset.select_columns(ACTION)
robot.connect()
try:
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(len(episode_frames)):
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action_array = actions[idx][ACTION]

View File

@@ -252,10 +252,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Wait for all processes to finish policy creation before continuing
accelerator.wait_for_everyone()
processor_pretrained_path = cfg.policy.pretrained_path
if (
getattr(cfg.policy, "use_relative_actions", False)
and processor_pretrained_path is not None
and not cfg.resume
):
logging.warning(
"use_relative_actions=true with pretrained processors can skip relative transforms if "
"the checkpoint processors do not define them. Building processors from current policy config."
)
processor_pretrained_path = None
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
postprocessor_kwargs = {}
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
@@ -263,7 +275,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if cfg.policy.type == "sarm":
processor_kwargs["dataset_meta"] = dataset.meta
if cfg.policy.pretrained_path is not None:
if processor_pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
"device_processor": {"device": device.type},
"normalizer_processor": {
@@ -285,7 +297,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
pretrained_path=processor_pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)

View File

@@ -15,7 +15,7 @@
This script:
1. Loads action chunks from LeRobotDataset (with episode sampling)
2. Optionally applies delta transforms (relative vs absolute actions)
2. Optionally applies relative transforms (relative vs absolute actions)
3. Extracts specified action dimensions for encoding
4. Applies normalization (MEAN_STD, MIN_MAX, QUANTILES, or other modes)
5. Trains FAST tokenizer (BPE on DCT coefficients) on the action chunks
@@ -32,8 +32,8 @@ lerobot-train-tokenizer \
--max_episodes=100 \
--sample_fraction=0.1 \
--encoded_dims="0:6" \
--delta_dims="0,1,2,3,4,5" \
--use_delta_transform=true \
--relative_dims="0,1,2,3,4,5" \
--use_relative_transform=true \
--state_key="observation.state" \
--normalization_mode="QUANTILES" \
--vocab_size=1024 \
@@ -82,10 +82,10 @@ class TokenizerTrainingConfig:
sample_fraction: float = 0.1
# Comma-separated dimension ranges to encode (e.g., "0:6,7:23")
encoded_dims: str = "0:6,7:23"
# Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
delta_dims: str | None = None
# Whether to apply delta transform (relative actions vs absolute actions)
use_delta_transform: bool = False
# Comma-separated dimension indices for relative transform (e.g., "0,1,2,3,4,5")
relative_dims: str | None = None
# Whether to apply relative transform (relative actions vs absolute actions)
use_relative_transform: bool = False
# Dataset key for state observations (default: "observation.state")
state_key: str = OBS_STATE
# Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
@@ -104,25 +104,27 @@ class TokenizerTrainingConfig:
hub_private: bool = False
def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray:
"""Apply delta transform to specified dimensions.
def apply_relative_transform(
state: np.ndarray, actions: np.ndarray, relative_dims: list[int] | None
) -> np.ndarray:
"""Apply relative transform to specified dimensions.
Args:
state: Current state [D]
actions: Future actions [D]
delta_dims: List of dimension indices to apply delta transform to
relative_dims: List of dimension indices to apply relative transform to
Returns:
Transformed actions [D]
"""
if delta_dims is None or len(delta_dims) == 0:
if relative_dims is None or len(relative_dims) == 0:
return actions
delta_actions = actions.copy()
for dim in delta_dims:
delta_actions[dim] = actions[dim] - state[dim]
relative_actions = actions.copy()
for dim in relative_dims:
relative_actions[dim] = actions[dim] - state[dim]
return delta_actions
return relative_actions
def apply_normalization(
@@ -185,7 +187,7 @@ def apply_normalization(
def process_episode(args):
"""Process single episode and return action chunks."""
dataset, ep_idx, action_horizon, delta_dims, sample_fraction, state_key, use_delta_transform = args
dataset, ep_idx, action_horizon, relative_dims, sample_fraction, state_key, use_relative_transform = args
try:
# get episode info
@@ -204,15 +206,15 @@ def process_episode(args):
for abs_idx in range(from_idx, to_idx):
# map absolute index to relative index if needed
if dataset._absolute_to_relative_idx is not None:
if abs_idx not in dataset._absolute_to_relative_idx:
if dataset.reader._absolute_to_relative_idx is not None:
if abs_idx not in dataset.reader._absolute_to_relative_idx:
# this episode's frames aren't in the filtered dataset
return None
rel_idx = dataset._absolute_to_relative_idx[abs_idx]
rel_idx = dataset.reader._absolute_to_relative_idx[abs_idx]
else:
rel_idx = abs_idx
frame = dataset.hf_dataset[rel_idx]
frame = dataset.get_raw_item(rel_idx)
# get state (could be from observation.state or other state key)
if state_key in frame:
@@ -222,7 +224,7 @@ def process_episode(args):
else np.array(frame[state_key])
)
else:
# if no state key, use zeros (no delta transform)
# if no state key, use zeros (no relative transform)
state = np.zeros_like(
frame[ACTION].numpy() if torch.is_tensor(frame[ACTION]) else np.array(frame[ACTION])
)
@@ -243,18 +245,18 @@ def process_episode(args):
current_state = states[i] # First state in chunk
future_absolute_actions = actions[i : i + action_horizon]
if use_delta_transform:
if use_relative_transform:
# relative actions
delta_chunk = np.zeros_like(future_absolute_actions)
relative_chunk = np.zeros_like(future_absolute_actions)
for t in range(action_horizon):
delta_chunk[t] = apply_delta_transform(
relative_chunk[t] = apply_relative_transform(
current_state,
future_absolute_actions[t],
delta_dims,
relative_dims,
)
action_chunks.append(delta_chunk)
action_chunks.append(relative_chunk)
else:
# absolute actions (no delta)
# absolute actions (no relative transform)
action_chunks.append(future_absolute_actions)
if len(action_chunks) == 0:
@@ -407,17 +409,20 @@ def train_tokenizer(cfg: TokenizerTrainingConfig):
total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
print(f"Encoding {total_encoded_dims} dimensions: {cfg.encoded_dims}")
# parse delta dimensions
delta_dim_list = None
if cfg.delta_dims is not None and cfg.delta_dims.strip():
delta_dim_list = [int(d.strip()) for d in cfg.delta_dims.split(",")]
print(f"Delta dimensions: {delta_dim_list}")
# parse relative dimensions
relative_dim_list = None
if cfg.relative_dims is not None and cfg.relative_dims.strip():
relative_dim_list = [int(d.strip()) for d in cfg.relative_dims.split(",")]
print(f"Relative dimensions: {relative_dim_list}")
else:
print("No delta dimensions specified")
print("No relative dimensions specified")
print(f"Use delta transform: {cfg.use_delta_transform}")
if cfg.use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0):
print("Warning: use_delta_transform=True but no delta_dims specified. No delta will be applied.")
print(f"Use relative transform: {cfg.use_relative_transform}")
if cfg.use_relative_transform and (relative_dim_list is None or len(relative_dim_list) == 0):
print(
"Warning: use_relative_transform=True but no relative_dims specified. "
"No relative transform will be applied."
)
print(f"Action horizon: {cfg.action_horizon}")
print(f"State key: {cfg.state_key}")
@@ -440,10 +445,10 @@ def train_tokenizer(cfg: TokenizerTrainingConfig):
dataset,
ep_idx,
cfg.action_horizon,
delta_dim_list,
relative_dim_list,
cfg.sample_fraction,
cfg.state_key,
cfg.use_delta_transform,
cfg.use_relative_transform,
)
)
if chunks is not None:
@@ -544,9 +549,9 @@ def train_tokenizer(cfg: TokenizerTrainingConfig):
"encoded_dims": cfg.encoded_dims,
"encoded_dim_ranges": encoded_dim_ranges,
"total_encoded_dims": total_encoded_dims,
"delta_dims": cfg.delta_dims,
"delta_dim_list": delta_dim_list,
"use_delta_transform": cfg.use_delta_transform,
"relative_dims": cfg.relative_dims,
"relative_dim_list": relative_dim_list,
"use_relative_transform": cfg.use_relative_transform,
"state_key": cfg.state_key,
"normalization_mode": norm_mode.value,
"action_horizon": cfg.action_horizon,

View File

@@ -341,8 +341,8 @@ class KeyboardRoverTeleop(KeyboardTeleop):
def action_features(self) -> dict:
"""Return action format for rover (linear and angular velocities)."""
return {
"linear.vel": float,
"angular.vel": float,
"linear_velocity": float,
"angular_velocity": float,
}
@property
@@ -366,7 +366,7 @@ class KeyboardRoverTeleop(KeyboardTeleop):
Get the current action based on pressed keys.
Returns:
RobotAction with 'linear.vel' and 'angular.vel' keys
RobotAction with 'linear_velocity' and 'angular_velocity' keys.
"""
before_read_t = time.perf_counter()
@@ -427,6 +427,6 @@ class KeyboardRoverTeleop(KeyboardTeleop):
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
return {
"linear.vel": linear_velocity,
"angular.vel": angular_velocity,
"linear_velocity": linear_velocity,
"angular_velocity": angular_velocity,
}

View File

@@ -32,9 +32,15 @@ from .config_openarm_mini import OpenArmMiniConfig
logger = logging.getLogger(__name__)
# Motors whose direction is inverted during readout
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5"]
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"]
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
# Leader joint 6 maps to follower joint 7 and vice versa
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
JOINT_REMAP_REVERSE = {"joint_7": "joint_6", "joint_6": "joint_7"}
GRIPPER_TELEOP_TO_DEGREES = -0.65
class OpenArmMini(Teleoperator):
"""
@@ -95,6 +101,8 @@ class OpenArmMini(Teleoperator):
@property
def action_features(self) -> dict[str, type]:
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
features: dict[str, type] = {}
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
@@ -276,16 +284,70 @@ class OpenArmMini(Teleoperator):
right_positions = self.bus_right.sync_read("Present_Position")
left_positions = self.bus_left.sync_read("Present_Position")
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
action: dict[str, Any] = {}
for motor, val in right_positions.items():
action[f"right_{motor}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
action[f"right_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"right_{target}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
for motor, val in left_positions.items():
action[f"left_{motor}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
action[f"left_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"left_{target}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action
def enable_torque(self) -> None:
"""Enable torque on both arms for position control."""
self.bus_right.enable_torque()
self.bus_left.enable_torque()
def disable_torque(self) -> None:
"""Disable torque on both arms for free movement."""
self.bus_right.disable_torque()
self.bus_left.disable_torque()
def write_goal_positions(self, positions: dict[str, float]) -> None:
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
right_goals: dict[str, float] = {}
left_goals: dict[str, float] = {}
for key, val in positions.items():
if not key.endswith(".pos"):
continue
motor_name = key.removesuffix(".pos")
if motor_name.startswith("right_"):
base = motor_name.removeprefix("right_")
# Reverse remap: follower joint_7 → leader joint_6 and vice versa
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
right_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
# Un-flip using the ORIGINAL motor name (target = leader motor)
right_goals[target] = -val if target in RIGHT_MOTORS_TO_FLIP else val
elif motor_name.startswith("left_"):
base = motor_name.removeprefix("left_")
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
left_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
left_goals[target] = -val if target in LEFT_MOTORS_TO_FLIP else val
if right_goals:
self.bus_right.sync_write("Goal_Position", right_goals)
if left_goals:
self.bus_left.sync_write("Goal_Position", left_goals)
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.")

View File

@@ -65,6 +65,10 @@ if "LEROBOT_HOME" in os.environ:
# cache dir
default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
# LeRobot's own revision-safe Hub cache (NOT the system-wide ~/.cache/huggingface/hub/).
# Used as the ``cache_dir`` argument to ``snapshot_download`` so that different
# dataset revisions are stored in isolated snapshot directories.
HF_LEROBOT_HUB_CACHE = HF_LEROBOT_HOME / "hub"
# calibration dir
default_calibration_path = HF_LEROBOT_HOME / "calibration"

View File

@@ -95,6 +95,8 @@ def init_logging(
file_handler.setLevel(file_level.upper())
logger.addHandler(file_handler)
logging.getLogger("httpx").setLevel(logging.WARNING)
def format_big_number(num, precision=0):
suffixes = ["", "K", "M", "B", "T", "Q"]

View File

@@ -80,7 +80,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
# HACK: We reload a batch with no delta_indices as `select_action` won't expect a timestamps dimension
# We simulate having an environment using a dataset by setting delta_indices to None and dropping tensors
# indicating padding (those ending with "_is_pad")
dataset.delta_indices = None
dataset.reader.delta_indices = None
batch = next(iter(dataloader))
obs = {}
for k in batch:

View File

@@ -0,0 +1,385 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contract tests for LeRobotDatasetMetadata."""
import json
import numpy as np
import pytest
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.utils import INFO_PATH
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_ROBOT_TYPE
# ── helpers ──────────────────────────────────────────────────────────
SIMPLE_FEATURES = {
"state": {"dtype": "float32", "shape": (6,), "names": None},
"action": {"dtype": "float32", "shape": (6,), "names": None},
}
VIDEO_FEATURES = {
**SIMPLE_FEATURES,
"observation.images.laptop": {
"dtype": "video",
"shape": (64, 96, 3),
"names": ["height", "width", "channels"],
"info": None,
},
}
IMAGE_FEATURES = {
**SIMPLE_FEATURES,
"observation.images.laptop": {
"dtype": "image",
"shape": (64, 96, 3),
"names": ["height", "width", "channels"],
"info": None,
},
}
def _make_dummy_stats(features: dict) -> dict:
"""Create minimal episode stats matching the given features."""
stats = {}
for key, ft in features.items():
if ft["dtype"] in ("image", "video"):
stats[key] = {
"max": np.ones((3, 1, 1), dtype=np.float32),
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32),
"min": np.zeros((3, 1, 1), dtype=np.float32),
"std": np.full((3, 1, 1), 0.25, dtype=np.float32),
"count": np.array([5]),
}
elif ft["dtype"] in ("float32", "float64", "int64"):
stats[key] = {
"max": np.ones(ft["shape"], dtype=np.float32),
"mean": np.full(ft["shape"], 0.5, dtype=np.float32),
"min": np.zeros(ft["shape"], dtype=np.float32),
"std": np.full(ft["shape"], 0.25, dtype=np.float32),
"count": np.array([5]),
}
return stats
# ── Construction contracts ───────────────────────────────────────────
def test_create_produces_valid_info_on_disk(tmp_path):
"""create() writes info.json and the returned object reflects the provided settings."""
root = tmp_path / "new_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/meta",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
robot_type=DUMMY_ROBOT_TYPE,
root=root,
use_videos=False,
)
# info.json was written to disk
assert (root / INFO_PATH).exists()
with open(root / INFO_PATH) as f:
info_on_disk = json.load(f)
assert meta.fps == DEFAULT_FPS
assert meta.robot_type == DUMMY_ROBOT_TYPE
assert "state" in meta.features
assert "action" in meta.features
assert info_on_disk["fps"] == DEFAULT_FPS
def test_create_starts_with_zero_counts(tmp_path):
"""A freshly created metadata has zero episode/frame/task counts."""
root = tmp_path / "empty_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/empty", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
assert meta.total_episodes == 0
assert meta.total_frames == 0
assert meta.total_tasks == 0
assert meta.tasks is None
assert meta.episodes is None
assert meta.stats is None
def test_create_with_videos_sets_video_path(tmp_path):
"""When features include video-dtype keys, create() produces a non-None video_path."""
root = tmp_path / "video_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/video", fps=DEFAULT_FPS, features=VIDEO_FEATURES, root=root, use_videos=True
)
assert meta.video_path is not None
assert len(meta.video_keys) == 1
assert "observation.images.laptop" in meta.video_keys
def test_create_without_videos_has_no_video_path(tmp_path):
"""When use_videos=False and no video features, video_path is None."""
root = tmp_path / "no_video"
meta = LeRobotDatasetMetadata.create(
repo_id="test/novid", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
assert meta.video_path is None
assert meta.video_keys == []
def test_create_raises_on_existing_directory(tmp_path):
"""create() raises if root directory already exists."""
root = tmp_path / "existing"
root.mkdir()
with pytest.raises(FileExistsError):
LeRobotDatasetMetadata.create(
repo_id="test/exists", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
def test_init_loads_existing_metadata(tmp_path, lerobot_dataset_metadata_factory, info_factory):
"""When metadata files exist on disk, __init__ loads them correctly."""
root = tmp_path / "load_test"
info = info_factory(total_episodes=3, total_frames=150, total_tasks=1, use_videos=False)
meta = lerobot_dataset_metadata_factory(root=root, info=info)
assert meta.total_episodes == 3
assert meta.total_frames == 150
assert meta.fps == info["fps"]
# ── Property accessors ───────────────────────────────────────────────
def test_property_accessors_reflect_info(tmp_path):
"""Properties return values consistent with the info dict."""
root = tmp_path / "props_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/props",
fps=DEFAULT_FPS,
features=IMAGE_FEATURES,
robot_type=DUMMY_ROBOT_TYPE,
root=root,
use_videos=False,
)
assert meta.fps == DEFAULT_FPS
assert meta.robot_type == DUMMY_ROBOT_TYPE
# shapes should be tuples
for _key, shape in meta.shapes.items():
assert isinstance(shape, tuple)
# image_keys should contain the image feature
assert "observation.images.laptop" in meta.image_keys
# camera_keys is a superset of image_keys and video_keys
assert set(meta.image_keys + meta.video_keys) == set(meta.camera_keys)
def test_data_path_is_formattable(tmp_path):
"""data_path contains format placeholders that can be .format()-ed."""
root = tmp_path / "fmt_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/fmt", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
formatted = meta.data_path.format(chunk_index=0, file_index=0)
assert "chunk" in formatted.lower() or "0" in formatted
# ── Task management ──────────────────────────────────────────────────
def test_save_episode_tasks_creates_tasks_dataframe(tmp_path):
"""On a fresh metadata, save_episode_tasks() creates the tasks DataFrame."""
root = tmp_path / "task_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/task", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
assert meta.tasks is None
meta.save_episode_tasks(["Pick up the cube"])
assert meta.tasks is not None
assert len(meta.tasks) == 1
assert "Pick up the cube" in meta.tasks.index
def test_save_episode_tasks_is_additive(tmp_path):
"""New tasks are added; existing tasks keep their original index."""
root = tmp_path / "additive_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/add", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
meta.save_episode_tasks(["Task A"])
idx_a = meta.get_task_index("Task A")
meta.save_episode_tasks(["Task A", "Task B"])
assert meta.get_task_index("Task A") == idx_a # unchanged
assert meta.get_task_index("Task B") is not None
assert len(meta.tasks) == 2
def test_get_task_index_returns_none_for_unknown(tmp_path):
"""get_task_index() returns None for an unknown task."""
root = tmp_path / "unknown_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/unknown", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
meta.save_episode_tasks(["Known task"])
assert meta.get_task_index("Known task") == 0
assert meta.get_task_index("Unknown task") is None
def test_save_episode_tasks_rejects_duplicates(tmp_path):
"""save_episode_tasks() raises ValueError on duplicate task strings."""
root = tmp_path / "dup_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/dup", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
with pytest.raises(ValueError):
meta.save_episode_tasks(["Same task", "Same task"])
# ── Episode saving ───────────────────────────────────────────────────
def test_save_episode_increments_counters(tmp_path):
"""After save_episode(), total_episodes and total_frames increase."""
root = tmp_path / "ep_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/ep", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
meta.save_episode_tasks(["Task 1"])
stats = _make_dummy_stats(meta.features)
meta.save_episode(
episode_index=0,
episode_length=10,
episode_tasks=["Task 1"],
episode_stats=stats,
episode_metadata={},
)
assert meta.total_episodes == 1
assert meta.total_frames == 10
def test_save_episode_updates_stats(tmp_path):
"""After save_episode(), .stats is non-None and has feature keys."""
root = tmp_path / "stats_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/stats", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
meta.save_episode_tasks(["Task 1"])
stats = _make_dummy_stats(meta.features)
meta.save_episode(
episode_index=0,
episode_length=5,
episode_tasks=["Task 1"],
episode_stats=stats,
episode_metadata={},
)
assert meta.stats is not None
# Stats should contain at least the user-defined feature keys
for key in SIMPLE_FEATURES:
assert key in meta.stats
# ── Chunk settings ───────────────────────────────────────────────────
def test_update_chunk_settings_persists(tmp_path):
"""update_chunk_settings() changes values and writes info.json."""
root = tmp_path / "chunk_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/chunk", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
original = meta.get_chunk_settings()
meta.update_chunk_settings(chunks_size=500)
assert meta.chunks_size == 500
assert meta.chunks_size != original["chunks_size"] or original["chunks_size"] == 500
# Verify persisted
with open(root / INFO_PATH) as f:
info_on_disk = json.load(f)
assert info_on_disk["chunks_size"] == 500
def test_update_chunk_settings_rejects_non_positive(tmp_path):
"""update_chunk_settings() raises ValueError for <= 0 values."""
root = tmp_path / "bad_chunk"
meta = LeRobotDatasetMetadata.create(
repo_id="test/bad", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
with pytest.raises(ValueError):
meta.update_chunk_settings(chunks_size=0)
with pytest.raises(ValueError):
meta.update_chunk_settings(data_files_size_in_mb=-1)
# ── Finalization ─────────────────────────────────────────────────────
def test_finalize_is_idempotent(tmp_path):
"""Calling finalize() multiple times does not raise."""
root = tmp_path / "fin_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/fin", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
)
meta.finalize()
meta.finalize() # second call should not raise
def test_finalize_flushes_buffered_metadata(tmp_path):
"""Episodes saved before finalize() are written to parquet."""
root = tmp_path / "flush_ds"
meta = LeRobotDatasetMetadata.create(
repo_id="test/flush",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
metadata_buffer_size=100, # large buffer so nothing auto-flushes
)
meta.save_episode_tasks(["Task 1"])
stats = _make_dummy_stats(meta.features)
# Save a few episodes (won't auto-flush since buffer_size=100)
for i in range(3):
meta.save_episode(
episode_index=i,
episode_length=5,
episode_tasks=["Task 1"],
episode_stats=stats,
episode_metadata={},
)
# Before finalize, the parquet might not exist yet
meta.finalize()
# After finalize, episodes parquet should exist
episodes_dir = root / "meta" / "episodes"
assert episodes_dir.exists()
parquet_files = list(episodes_dir.rglob("*.parquet"))
assert len(parquet_files) > 0

View File

@@ -0,0 +1,168 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contract tests for DatasetReader."""
from lerobot.datasets.dataset_reader import DatasetReader
from lerobot.datasets.video_utils import get_safe_default_codec
# ── Loading ──────────────────────────────────────────────────────────
def test_try_load_returns_true_when_data_exists(tmp_path, lerobot_dataset_factory):
"""Given a fully written dataset, try_load() returns True."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False
)
reader = DatasetReader(
meta=dataset.meta,
root=dataset.root,
episodes=None,
tolerance_s=1e-4,
video_backend=get_safe_default_codec(),
delta_timestamps=None,
image_transforms=None,
)
assert reader.try_load() is True
assert reader.hf_dataset is not None
def test_try_load_returns_false_when_no_data(tmp_path):
"""When only metadata exists (no data/ parquets), try_load() returns False."""
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
root = tmp_path / "meta_only"
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
meta = LeRobotDatasetMetadata.create(
repo_id="test/meta_only", fps=30, features=features, root=root, use_videos=False
)
reader = DatasetReader(
meta=meta,
root=meta.root,
episodes=None,
tolerance_s=1e-4,
video_backend=get_safe_default_codec(),
delta_timestamps=None,
image_transforms=None,
)
assert reader.try_load() is False
assert reader.hf_dataset is None
# ── Counts ───────────────────────────────────────────────────────────
def test_num_frames_without_filter(tmp_path, lerobot_dataset_factory):
"""With episodes=None, num_frames equals total_frames."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=3, total_frames=60, use_videos=False
)
assert dataset.reader.num_frames == dataset.meta.total_frames
def test_num_episodes_without_filter(tmp_path, lerobot_dataset_factory):
"""With episodes=None, num_episodes equals total_episodes."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=3, total_frames=60, use_videos=False
)
assert dataset.reader.num_episodes == dataset.meta.total_episodes
def test_num_frames_with_episode_filter(tmp_path, lerobot_dataset_factory):
"""When filtering to a subset, only those episodes' frames are counted."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=5, total_frames=100, episodes=[0, 2], use_videos=False
)
# Filtered frames should be less than total
assert dataset.reader.num_frames <= dataset.meta.total_frames
assert dataset.reader.num_episodes == 2
# ── get_item ─────────────────────────────────────────────────────────
def test_get_item_returns_expected_keys(tmp_path, lerobot_dataset_factory):
"""get_item(0) returns a dict with expected keys."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False
)
item = dataset.reader.get_item(0)
# Standard keys that must always be present
for key in ["index", "episode_index", "frame_index", "timestamp", "task_index", "task"]:
assert key in item, f"Missing key: {key}"
def test_get_item_values_are_correct(tmp_path, lerobot_dataset_factory):
"""get_item() returns correct index and episode_index."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False
)
item_0 = dataset.reader.get_item(0)
assert item_0["index"].item() == 0
assert item_0["episode_index"].item() == 0
# ── Transforms ───────────────────────────────────────────────────────
def test_image_transforms_are_applied(tmp_path, lerobot_dataset_factory):
"""When image_transforms is provided, get_item() applies it to camera keys."""
transform_called = {"count": 0}
def sentinel_transform(img):
transform_called["count"] += 1
return img
dataset = lerobot_dataset_factory(
root=tmp_path / "ds",
total_episodes=1,
total_frames=5,
use_videos=False,
image_transforms=sentinel_transform,
)
item = dataset[0] # noqa: F841
# Should have been called once per camera key per frame
num_cameras = len(dataset.meta.camera_keys)
if num_cameras > 0:
assert transform_called["count"] >= 1
# ── File paths ───────────────────────────────────────────────────────
def test_get_episodes_file_paths_returns_data_paths(tmp_path, lerobot_dataset_factory):
"""get_episodes_file_paths() returns paths including data/ paths."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False
)
paths = dataset.reader.get_episodes_file_paths()
assert len(paths) > 0
assert any("data/" in str(p) for p in paths)
def test_get_episodes_file_paths_includes_video_paths(tmp_path, lerobot_dataset_factory):
"""When dataset has video keys, file paths include video/ paths."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=True
)
if len(dataset.meta.video_keys) > 0:
paths = dataset.reader.get_episodes_file_paths()
assert any("video" in str(p).lower() for p in paths)

View File

@@ -0,0 +1,226 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contract tests for DatasetWriter."""
from pathlib import Path
from unittest.mock import patch
import numpy as np
import pytest
import torch
from PIL import Image
from lerobot.datasets.dataset_writer import _encode_video_worker
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
SIMPLE_FEATURES = {
"state": {"dtype": "float32", "shape": (6,), "names": None},
"action": {"dtype": "float32", "shape": (6,), "names": None},
}
def _make_frame(features: dict, task: str = "Dummy task") -> dict:
"""Build a valid frame dict for the given features."""
frame = {"task": task}
for key, ft in features.items():
if ft["dtype"] in ("image", "video"):
frame[key] = np.random.randint(0, 256, size=ft["shape"], dtype=np.uint8)
elif ft["dtype"] in ("float32", "float64"):
frame[key] = torch.randn(ft["shape"])
elif ft["dtype"] == "int64":
frame[key] = torch.zeros(ft["shape"], dtype=torch.int64)
return frame
# ── Existing encode_video_worker tests ───────────────────────────────
def test_encode_video_worker_forwards_vcodec(tmp_path):
"""_encode_video_worker correctly forwards the vcodec parameter."""
video_key = "observation.images.laptop"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
img_dir = tmp_path / Path(fpath).parent
img_dir.mkdir(parents=True, exist_ok=True)
Image.new("RGB", (64, 64), color="red").save(img_dir / "frame-000000.png")
captured_kwargs = {}
def mock_encode(imgs_dir, video_path, fps, **kwargs):
captured_kwargs.update(kwargs)
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
Path(video_path).touch()
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
_encode_video_worker(video_key, 0, tmp_path, fps=30, vcodec="h264")
assert captured_kwargs["vcodec"] == "h264"
def test_encode_video_worker_default_vcodec(tmp_path):
"""_encode_video_worker uses libsvtav1 as the default codec."""
video_key = "observation.images.laptop"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
img_dir = tmp_path / Path(fpath).parent
img_dir.mkdir(parents=True, exist_ok=True)
Image.new("RGB", (64, 64), color="red").save(img_dir / "frame-000000.png")
captured_kwargs = {}
def mock_encode(imgs_dir, video_path, fps, **kwargs):
captured_kwargs.update(kwargs)
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
Path(video_path).touch()
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
_encode_video_worker(video_key, 0, tmp_path, fps=30)
assert captured_kwargs["vcodec"] == "libsvtav1"
# ── add_frame contracts ──────────────────────────────────────────────
def test_add_frame_increments_buffer_size(tmp_path):
"""Each add_frame() call increases episode_buffer['size'] by 1."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
assert dataset.writer.episode_buffer["size"] == 0
dataset.add_frame(_make_frame(SIMPLE_FEATURES))
assert dataset.writer.episode_buffer["size"] == 1
dataset.add_frame(_make_frame(SIMPLE_FEATURES))
assert dataset.writer.episode_buffer["size"] == 2
def test_add_frame_rejects_missing_feature(tmp_path):
"""add_frame() raises ValueError when a required feature is missing."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
with pytest.raises(ValueError, match="Missing features"):
dataset.add_frame({"task": "Dummy task", "state": torch.randn(6)})
# missing 'action'
# ── save_episode contracts ───────────────────────────────────────────
def test_save_episode_writes_parquet(tmp_path):
"""After save_episode(), at least one .parquet file exists under data/."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
for _ in range(3):
dataset.add_frame(_make_frame(SIMPLE_FEATURES))
dataset.save_episode()
parquet_files = list((tmp_path / "ds" / "data").rglob("*.parquet"))
assert len(parquet_files) > 0
def test_save_episode_updates_counters(tmp_path):
"""After save_episode(), metadata counters are updated."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
for _ in range(5):
dataset.add_frame(_make_frame(SIMPLE_FEATURES))
dataset.save_episode()
assert dataset.meta.total_episodes == 1
assert dataset.meta.total_frames == 5
def test_save_episode_resets_buffer(tmp_path):
"""After save_episode(), the episode buffer is reset."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
for _ in range(3):
dataset.add_frame(_make_frame(SIMPLE_FEATURES))
dataset.save_episode()
assert dataset.writer.episode_buffer["size"] == 0
def test_save_multiple_episodes(tmp_path):
"""Recording 3 episodes results in correct total counts."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
total_frames = 0
for ep in range(3):
n_frames = ep + 2 # 2, 3, 4
for _ in range(n_frames):
dataset.add_frame(_make_frame(SIMPLE_FEATURES))
dataset.save_episode()
total_frames += n_frames
assert dataset.meta.total_episodes == 3
assert dataset.meta.total_frames == total_frames
# ── clear / lifecycle ────────────────────────────────────────────────
def test_clear_resets_buffer(tmp_path):
"""clear_episode_buffer() resets the buffer size to 0."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
dataset.add_frame(_make_frame(SIMPLE_FEATURES))
assert dataset.writer.episode_buffer["size"] == 1
dataset.clear_episode_buffer()
assert dataset.writer.episode_buffer["size"] == 0
def test_finalize_is_idempotent(tmp_path):
"""Calling finalize() twice does not raise."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
for _ in range(3):
dataset.add_frame(_make_frame(SIMPLE_FEATURES))
dataset.save_episode()
dataset.finalize()
dataset.finalize() # second call should not raise
def test_finalize_then_read_roundtrip(tmp_path):
"""Write data, finalize, re-open, and verify data matches."""
root = tmp_path / "roundtrip"
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=features, root=root)
# Record known values
known_states = []
for i in range(5):
state = torch.tensor([float(i), float(i * 10)])
known_states.append(state)
dataset.add_frame({"task": "Test task", "state": state})
dataset.save_episode()
dataset.finalize()
# Read back
for i in range(5):
item = dataset[i]
assert torch.allclose(item["state"], known_states[i], atol=1e-5)

View File

@@ -32,10 +32,7 @@ from lerobot.datasets.factory import make_dataset
from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features
from lerobot.datasets.image_writer import image_array_to_pil_image
from lerobot.datasets.io_utils import hf_transform_to_torch
from lerobot.datasets.lerobot_dataset import (
LeRobotDataset,
_encode_video_worker,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
@@ -72,7 +69,7 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
"""
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
objects have the same sets of attributes defined.
objects have the same sets of facade-level attributes defined.
"""
# Instantiate both ways
robot = make_robot_from_config(MockRobotConfig())
@@ -87,6 +84,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
# Facade-level attributes should match between __init__ and create()
init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys())
@@ -214,6 +212,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert len(dataset) == 1
assert dataset[0]["task"] == "Dummy task"
@@ -226,6 +225,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["state"].shape == torch.Size([2])
@@ -235,6 +235,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["state"].shape == torch.Size([2, 4])
@@ -244,6 +245,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
@@ -253,6 +255,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
@@ -262,6 +265,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
@@ -271,6 +275,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["state"].ndim == 0
@@ -280,6 +285,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["caption"] == "Dummy caption"
@@ -315,6 +321,7 @@ def test_add_frame_image(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -323,6 +330,7 @@ def test_add_frame_image_h_w_c(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -332,6 +340,7 @@ def test_add_frame_image_uint8(image_dataset):
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": image, "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -341,6 +350,7 @@ def test_add_frame_image_pil(image_dataset):
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -361,7 +371,7 @@ def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory):
ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image)
ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
ds_img.save_episode()
img_dir = ds_img._get_image_file_dir(0, image_key)
img_dir = ds_img.writer._get_image_file_dir(0, image_key)
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
@@ -374,10 +384,10 @@ def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory):
}
ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video)
ds_vid.batch_encoding_size = 1
ds_vid.writer._batch_encoding_size = 1
ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
ds_vid.save_episode()
vid_img_dir = ds_vid._get_image_file_dir(0, vid_key)
vid_img_dir = ds_vid.writer._get_image_file_dir(0, vid_key)
assert not vid_img_dir.exists(), (
"Temporary image directory should be removed when batch_encoding_size == 1"
)
@@ -402,8 +412,8 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
}
)
ds_mixed.save_episode()
img_dir = ds_mixed._get_image_file_dir(0, image_key)
vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key)
img_dir = ds_mixed.writer._get_image_file_dir(0, image_key)
vid_img_dir = ds_mixed.writer._get_image_file_dir(0, vid_key)
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
assert vid_img_dir.exists(), (
"Temporary image directory should not be removed for video features when batch_encoding_size == 2"
@@ -631,29 +641,29 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
)
# Test hf_dataset is None
dataset.hf_dataset = None
assert dataset._check_cached_episodes_sufficient() is False
dataset.reader.hf_dataset = None
assert dataset.reader._check_cached_episodes_sufficient() is False
# Test hf_dataset is empty
import datasets
empty_features = get_hf_features_from_features(dataset.features)
dataset.hf_dataset = datasets.Dataset.from_dict(
dataset.reader.hf_dataset = datasets.Dataset.from_dict(
{key: [] for key in empty_features}, features=empty_features
)
dataset.hf_dataset.set_transform(hf_transform_to_torch)
assert dataset._check_cached_episodes_sufficient() is False
dataset.reader.hf_dataset.set_transform(hf_transform_to_torch)
assert dataset.reader._check_cached_episodes_sufficient() is False
# Restore the original dataset for remaining tests
dataset.hf_dataset = dataset.load_hf_dataset()
dataset.reader.hf_dataset = dataset.reader._load_hf_dataset()
# Test all episodes requested (self.episodes = None) and all are available
dataset.episodes = None
assert dataset._check_cached_episodes_sufficient() is True
dataset.reader.episodes = None
assert dataset.reader._check_cached_episodes_sufficient() is True
# Test specific episodes requested that are all available
dataset.episodes = [0, 2, 4]
assert dataset._check_cached_episodes_sufficient() is True
dataset.reader.episodes = [0, 2, 4]
assert dataset.reader._check_cached_episodes_sufficient() is True
# Test request episodes that don't exist in the cached dataset
# Create a dataset with only episodes 0, 1, 2
@@ -665,8 +675,8 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
)
# Request episodes that include non-existent ones
limited_dataset.episodes = [0, 1, 2, 3, 4]
assert limited_dataset._check_cached_episodes_sufficient() is False
limited_dataset.reader.episodes = [0, 1, 2, 3, 4]
assert limited_dataset.reader._check_cached_episodes_sufficient() is False
# Test create a dataset with sparse episodes (e.g., only episodes 0, 2, 4)
# First create the full dataset structure
@@ -702,22 +712,22 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
filtered_data[key] = filtered_values
sparse_dataset.hf_dataset = datasets.Dataset.from_dict(
sparse_dataset.reader.hf_dataset = datasets.Dataset.from_dict(
filtered_data, features=get_hf_features_from_features(sparse_dataset.features)
)
sparse_dataset.hf_dataset.set_transform(hf_transform_to_torch)
sparse_dataset.reader.hf_dataset.set_transform(hf_transform_to_torch)
# Test requesting all episodes when only some are cached
sparse_dataset.episodes = None
assert sparse_dataset._check_cached_episodes_sufficient() is False
sparse_dataset.reader.episodes = None
assert sparse_dataset.reader._check_cached_episodes_sufficient() is False
# Test requesting only the available episodes
sparse_dataset.episodes = [0, 2, 4]
assert sparse_dataset._check_cached_episodes_sufficient() is True
sparse_dataset.reader.episodes = [0, 2, 4]
assert sparse_dataset.reader._check_cached_episodes_sufficient() is True
# Test requesting a mix of available and unavailable episodes
sparse_dataset.episodes = [0, 1, 2]
assert sparse_dataset._check_cached_episodes_sufficient() is False
sparse_dataset.reader.episodes = [0, 1, 2]
assert sparse_dataset.reader._check_cached_episodes_sufficient() is False
def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
@@ -1189,13 +1199,13 @@ def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory):
del dataset_verify
# Phase 3: Resume recording - add more episodes
dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")
dataset_resumed = LeRobotDataset.resume(initial_repo_id, root=initial_root, revision="v3.0")
assert dataset_resumed.meta.total_episodes == initial_episodes
assert dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode
assert dataset_resumed.latest_episode is None # Not recording yet
assert dataset_resumed.writer is None
assert dataset_resumed.meta.writer is None
assert dataset_resumed.writer._latest_episode is None # Not recording yet
assert dataset_resumed.writer._pq_writer is None
assert dataset_resumed.meta._pq_writer is None
additional_episodes = 2
for ep_idx in range(initial_episodes, initial_episodes + additional_episodes):
@@ -1271,7 +1281,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
dataset.meta.update_chunk_settings(data_files_size_in_mb=100)
assert dataset._current_file_start_frame is None
assert dataset.writer._current_file_start_frame is None
frames_per_episode = 10
for _ in range(frames_per_episode):
@@ -1284,7 +1294,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
)
dataset.save_episode()
assert dataset._current_file_start_frame == 0
assert dataset.writer._current_file_start_frame == 0
assert dataset.meta.total_episodes == 1
assert dataset.meta.total_frames == frames_per_episode
@@ -1298,12 +1308,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
)
dataset.save_episode()
assert dataset._current_file_start_frame == 0
assert dataset.writer._current_file_start_frame == 0
assert dataset.meta.total_episodes == 2
assert dataset.meta.total_frames == 2 * frames_per_episode
ep1_chunk = dataset.latest_episode["data/chunk_index"]
ep1_file = dataset.latest_episode["data/file_index"]
ep1_chunk = dataset.writer._latest_episode["data/chunk_index"]
ep1_file = dataset.writer._latest_episode["data/file_index"]
assert ep1_chunk == 0
assert ep1_file == 0
@@ -1317,12 +1327,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
)
dataset.save_episode()
assert dataset._current_file_start_frame == 0
assert dataset.writer._current_file_start_frame == 0
assert dataset.meta.total_episodes == 3
assert dataset.meta.total_frames == 3 * frames_per_episode
ep2_chunk = dataset.latest_episode["data/chunk_index"]
ep2_file = dataset.latest_episode["data/file_index"]
ep2_chunk = dataset.writer._latest_episode["data/chunk_index"]
ep2_file = dataset.writer._latest_episode["data/file_index"]
assert ep2_chunk == 0
assert ep2_file == 0
@@ -1354,82 +1364,6 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
assert frame["episode_index"].item() == expected_ep
def test_encode_video_worker_forwards_vcodec(tmp_path):
"""Test that _encode_video_worker correctly forwards the vcodec parameter to encode_video_frames."""
from unittest.mock import patch
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
# Create the expected directory structure
video_key = "observation.images.laptop"
episode_index = 0
frame_index = 0
fpath = DEFAULT_IMAGE_PATH.format(
image_key=video_key, episode_index=episode_index, frame_index=frame_index
)
img_dir = tmp_path / Path(fpath).parent
img_dir.mkdir(parents=True, exist_ok=True)
# Create a dummy image file
dummy_img = Image.new("RGB", (64, 64), color="red")
dummy_img.save(img_dir / "frame-000000.png")
# Track what vcodec was passed to encode_video_frames
captured_kwargs = {}
def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
captured_kwargs.update(kwargs)
# Create a dummy output file so the worker doesn't fail
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
Path(video_path).touch()
with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
# Test with h264 codec
_encode_video_worker(video_key, episode_index, tmp_path, fps=30, vcodec="h264")
assert "vcodec" in captured_kwargs
assert captured_kwargs["vcodec"] == "h264"
def test_encode_video_worker_default_vcodec(tmp_path):
"""Test that _encode_video_worker uses libsvtav1 as the default codec."""
from unittest.mock import patch
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
# Create the expected directory structure
video_key = "observation.images.laptop"
episode_index = 0
frame_index = 0
fpath = DEFAULT_IMAGE_PATH.format(
image_key=video_key, episode_index=episode_index, frame_index=frame_index
)
img_dir = tmp_path / Path(fpath).parent
img_dir.mkdir(parents=True, exist_ok=True)
# Create a dummy image file
dummy_img = Image.new("RGB", (64, 64), color="red")
dummy_img.save(img_dir / "frame-000000.png")
# Track what vcodec was passed to encode_video_frames
captured_kwargs = {}
def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
captured_kwargs.update(kwargs)
# Create a dummy output file so the worker doesn't fail
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
Path(video_path).touch()
with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
# Test with default codec (no vcodec specified)
_encode_video_worker(video_key, episode_index, tmp_path, fps=30)
assert "vcodec" in captured_kwargs
assert captured_kwargs["vcodec"] == "libsvtav1"
def test_lerobot_dataset_vcodec_validation():
"""Test that LeRobotDataset validates the vcodec parameter."""
# Test that invalid vcodec raises ValueError

View File

@@ -352,10 +352,14 @@ def test_with_different_image_formats(tmp_path, img_array_factory):
def test_safe_stop_image_writer_decorator():
class MockDataset:
class MockWriter:
def __init__(self):
self.image_writer = MagicMock(spec=AsyncImageWriter)
class MockDataset:
def __init__(self):
self.writer = MockWriter()
@safe_stop_image_writer
def function_that_raises_exception(dataset=None):
raise Exception("Test exception")
@@ -366,7 +370,7 @@ def test_safe_stop_image_writer_decorator():
function_that_raises_exception(dataset=dataset)
assert str(exc_info.value) == "Test exception"
dataset.image_writer.stop.assert_called_once()
dataset.writer.image_writer.stop.assert_called_once()
def test_main_process_time(tmp_path, img_tensor_factory):

View File

@@ -0,0 +1,632 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contract tests for the LeRobotDataset facade.
Tests focus on mode contracts (read-only, write-only, resume), guards,
property delegation, and the full create-record-finalize-read lifecycle.
"""
from pathlib import Path
from unittest.mock import Mock
import pytest
import torch
import lerobot.datasets.dataset_metadata as dataset_metadata_module
import lerobot.datasets.lerobot_dataset as lerobot_dataset_module
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.dataset_reader import DatasetReader
from lerobot.datasets.dataset_writer import DatasetWriter
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
SIMPLE_FEATURES = {
"state": {"dtype": "float32", "shape": (2,), "names": None},
}
SNAPSHOT_MAIN_FEATURES = {
**SIMPLE_FEATURES,
"test": {"dtype": "float32", "shape": (2,), "names": None},
}
def _make_frame(task: str = "Dummy task") -> dict:
return {"task": task, "state": torch.randn(2)}
def _set_default_cache_root(monkeypatch: pytest.MonkeyPatch, cache_root: Path) -> None:
monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HOME", cache_root)
monkeypatch.setattr(dataset_metadata_module, "HF_LEROBOT_HUB_CACHE", cache_root / "hub")
monkeypatch.setattr(lerobot_dataset_module, "HF_LEROBOT_HUB_CACHE", cache_root / "hub")
def _write_dataset_tree(
root: Path,
*,
motor_features: dict[str, dict],
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
create_info,
create_stats,
create_tasks,
create_episodes,
create_hf_dataset,
) -> None:
root.mkdir(parents=True, exist_ok=True)
info = info_factory(
total_episodes=1,
total_frames=3,
total_tasks=1,
use_videos=False,
motor_features=motor_features,
camera_features={},
)
tasks = tasks_factory(total_tasks=1)
episodes = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=1,
total_frames=3,
tasks=tasks,
)
stats = stats_factory(features=info["features"])
hf_dataset = hf_dataset_factory(
features=info["features"],
tasks=tasks,
episodes=episodes,
fps=info["fps"],
)
create_info(root, info)
create_stats(root, stats)
create_tasks(root, tasks)
create_episodes(root, episodes)
create_hf_dataset(root, hf_dataset)
# ── Read-only mode (via __init__) ────────────────────────────────────
def test_init_creates_reader_no_writer(tmp_path, lerobot_dataset_factory):
"""__init__() sets reader to a DatasetReader and writer to None."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False
)
assert isinstance(dataset.reader, DatasetReader)
assert dataset.writer is None
def test_init_loads_data(tmp_path, lerobot_dataset_factory):
"""After __init__(), the dataset has data and len > 0."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False
)
assert len(dataset) > 0
def test_getitem_works_in_read_mode(tmp_path, lerobot_dataset_factory):
"""dataset[0] returns a dict with expected keys in read-only mode."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False
)
item = dataset[0]
assert isinstance(item, dict)
assert "index" in item
assert "task" in item
def test_len_matches_num_frames(tmp_path, lerobot_dataset_factory):
"""len(dataset) equals dataset.num_frames."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=2, total_frames=30, use_videos=False
)
assert len(dataset) == dataset.num_frames
def test_metadata_without_root_uses_hub_cache_snapshot_download(
tmp_path,
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
create_info,
create_stats,
create_tasks,
create_episodes,
create_hf_dataset,
monkeypatch,
):
"""Metadata refresh uses the dedicated Hub cache instead of a shared local_dir mirror."""
repo_id = DUMMY_REPO_ID
cache_root = tmp_path / "lerobot_cache"
snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"
_write_dataset_tree(
snapshot_root,
motor_features=SNAPSHOT_MAIN_FEATURES,
info_factory=info_factory,
stats_factory=stats_factory,
tasks_factory=tasks_factory,
episodes_factory=episodes_factory,
hf_dataset_factory=hf_dataset_factory,
create_info=create_info,
create_stats=create_stats,
create_tasks=create_tasks,
create_episodes=create_episodes,
create_hf_dataset=create_hf_dataset,
)
_set_default_cache_root(monkeypatch, cache_root)
snapshot_download = Mock(return_value=str(snapshot_root))
monkeypatch.setattr(dataset_metadata_module, "snapshot_download", snapshot_download)
meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main", force_cache_sync=True)
assert meta.root == snapshot_root
assert snapshot_download.call_count == 1
assert snapshot_download.call_args.args == (repo_id,)
assert snapshot_download.call_args.kwargs == {
"repo_type": "dataset",
"revision": "main",
"cache_dir": cache_root / "hub",
"allow_patterns": "meta/",
"ignore_patterns": None,
}
def test_without_root_reads_different_revisions_from_distinct_snapshot_roots(
tmp_path,
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
create_info,
create_stats,
create_tasks,
create_episodes,
create_hf_dataset,
monkeypatch,
):
"""Different revisions resolve to different on-disk snapshot roots."""
repo_id = DUMMY_REPO_ID
old_revision = "b59010db93eb6cc3cf06ef2f7cae1bbe62b726d9"
cache_root = tmp_path / "lerobot_cache"
main_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"
old_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-old"
_write_dataset_tree(
main_root,
motor_features=SNAPSHOT_MAIN_FEATURES,
info_factory=info_factory,
stats_factory=stats_factory,
tasks_factory=tasks_factory,
episodes_factory=episodes_factory,
hf_dataset_factory=hf_dataset_factory,
create_info=create_info,
create_stats=create_stats,
create_tasks=create_tasks,
create_episodes=create_episodes,
create_hf_dataset=create_hf_dataset,
)
_write_dataset_tree(
old_root,
motor_features=SIMPLE_FEATURES,
info_factory=info_factory,
stats_factory=stats_factory,
tasks_factory=tasks_factory,
episodes_factory=episodes_factory,
hf_dataset_factory=hf_dataset_factory,
create_info=create_info,
create_stats=create_stats,
create_tasks=create_tasks,
create_episodes=create_episodes,
create_hf_dataset=create_hf_dataset,
)
_set_default_cache_root(monkeypatch, cache_root)
snapshot_roots = {
"main": main_root,
old_revision: old_root,
}
meta_snapshot_download = Mock(
side_effect=lambda repo_id, **kwargs: str(snapshot_roots[kwargs["revision"]])
)
data_snapshot_download = Mock(
side_effect=lambda repo_id, **kwargs: str(snapshot_roots[kwargs["revision"]])
)
monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_snapshot_download)
monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_snapshot_download)
main_dataset = LeRobotDataset(
repo_id=repo_id, revision="main", download_videos=False, force_cache_sync=True
)
old_dataset = LeRobotDataset(
repo_id=repo_id, revision=old_revision, download_videos=False, force_cache_sync=True
)
assert main_dataset.root == main_root
assert old_dataset.root == old_root
assert "test" in main_dataset.hf_dataset.column_names
assert "test" not in old_dataset.hf_dataset.column_names
# Metadata downloads use cache_dir, not local_dir
assert meta_snapshot_download.call_count == 2
for download_call in meta_snapshot_download.call_args_list:
assert download_call.kwargs["cache_dir"] == cache_root / "hub"
assert "local_dir" not in download_call.kwargs
# Data downloads also use cache_dir, not local_dir
assert data_snapshot_download.call_count == 2
for download_call in data_snapshot_download.call_args_list:
assert download_call.kwargs["cache_dir"] == cache_root / "hub"
assert "local_dir" not in download_call.kwargs
def test_metadata_without_root_ignores_legacy_local_dir_cache(
tmp_path,
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
create_info,
create_stats,
create_tasks,
create_episodes,
create_hf_dataset,
monkeypatch,
):
"""Legacy local-dir mirrors are bypassed in favor of revision-safe snapshots."""
repo_id = DUMMY_REPO_ID
cache_root = tmp_path / "lerobot_cache"
legacy_root = cache_root / repo_id
snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"
_write_dataset_tree(
legacy_root,
motor_features=SIMPLE_FEATURES,
info_factory=info_factory,
stats_factory=stats_factory,
tasks_factory=tasks_factory,
episodes_factory=episodes_factory,
hf_dataset_factory=hf_dataset_factory,
create_info=create_info,
create_stats=create_stats,
create_tasks=create_tasks,
create_episodes=create_episodes,
create_hf_dataset=create_hf_dataset,
)
(legacy_root / ".cache" / "huggingface" / "download").mkdir(parents=True, exist_ok=True)
_write_dataset_tree(
snapshot_root,
motor_features=SNAPSHOT_MAIN_FEATURES,
info_factory=info_factory,
stats_factory=stats_factory,
tasks_factory=tasks_factory,
episodes_factory=episodes_factory,
hf_dataset_factory=hf_dataset_factory,
create_info=create_info,
create_stats=create_stats,
create_tasks=create_tasks,
create_episodes=create_episodes,
create_hf_dataset=create_hf_dataset,
)
_set_default_cache_root(monkeypatch, cache_root)
snapshot_download = Mock(return_value=str(snapshot_root))
monkeypatch.setattr(dataset_metadata_module, "snapshot_download", snapshot_download)
meta = LeRobotDatasetMetadata(repo_id=repo_id, revision="main")
assert meta.root == snapshot_root
assert "test" in meta.features
assert snapshot_download.call_count == 1
def test_download_without_root_uses_hub_cache(
tmp_path,
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
create_info,
create_stats,
create_tasks,
create_episodes,
create_hf_dataset,
monkeypatch,
):
"""LeRobotDataset._download() uses cache_dir (not local_dir) when root is not provided."""
repo_id = DUMMY_REPO_ID
cache_root = tmp_path / "lerobot_cache"
snapshot_root = cache_root / "hub" / "datasets--dummy--repo" / "snapshots" / "commit-main"
# Pre-populate snapshot directory so metadata loads succeed, but leave
# data absent so that _download() is triggered.
_write_dataset_tree(
snapshot_root,
motor_features=SIMPLE_FEATURES,
info_factory=info_factory,
stats_factory=stats_factory,
tasks_factory=tasks_factory,
episodes_factory=episodes_factory,
hf_dataset_factory=hf_dataset_factory,
create_info=create_info,
create_stats=create_stats,
create_tasks=create_tasks,
create_episodes=create_episodes,
create_hf_dataset=create_hf_dataset,
)
_set_default_cache_root(monkeypatch, cache_root)
meta_snapshot_download = Mock(return_value=str(snapshot_root))
monkeypatch.setattr(dataset_metadata_module, "snapshot_download", meta_snapshot_download)
# Mock the data snapshot_download to return the same root (data already
# exists there from _write_dataset_tree).
data_snapshot_download = Mock(return_value=str(snapshot_root))
monkeypatch.setattr(lerobot_dataset_module, "snapshot_download", data_snapshot_download)
LeRobotDataset(repo_id=repo_id, revision="main", force_cache_sync=True)
# _download() should have called snapshot_download with cache_dir
assert data_snapshot_download.call_count == 1
call_kwargs = data_snapshot_download.call_args.kwargs
assert call_kwargs["cache_dir"] == cache_root / "hub"
assert "local_dir" not in call_kwargs
# ── Write-only mode (via create()) ──────────────────────────────────
def test_create_sets_writer_no_reader(tmp_path):
"""create() sets writer to a DatasetWriter and reader to None."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
assert isinstance(dataset.writer, DatasetWriter)
assert dataset.reader is None
def test_create_initial_counts_zero(tmp_path):
"""After create(), num_episodes == 0 and num_frames == 0."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
assert dataset.num_episodes == 0
assert dataset.num_frames == 0
def test_add_frame_works_in_write_mode(tmp_path):
"""add_frame() succeeds on a dataset created via create()."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
dataset.add_frame(_make_frame()) # should not raise
# ── Resume mode ──────────────────────────────────────────────────────
def test_resume_creates_writer(tmp_path):
"""After resume(), writer is a DatasetWriter."""
root = tmp_path / "resume_ds"
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root
)
for _ in range(3):
dataset.add_frame(_make_frame())
dataset.save_episode()
dataset.finalize()
resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root)
assert isinstance(resumed.writer, DatasetWriter)
def test_resume_preserves_episode_count(tmp_path):
"""After resume(), existing episodes are counted."""
root = tmp_path / "resume_ds"
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root
)
for _ in range(3):
dataset.add_frame(_make_frame())
dataset.save_episode()
dataset.finalize()
resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root)
assert resumed.meta.total_episodes == 1
def test_resume_can_add_more_episodes(tmp_path):
"""After resume(), new episodes can be added."""
root = tmp_path / "resume_ds"
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root
)
for _ in range(3):
dataset.add_frame(_make_frame())
dataset.save_episode()
dataset.finalize()
resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root)
for _ in range(2):
resumed.add_frame(_make_frame())
resumed.save_episode()
assert resumed.meta.total_episodes == 2
# ── Writer guard ─────────────────────────────────────────────────────
def test_add_frame_raises_without_writer(tmp_path, lerobot_dataset_factory):
"""add_frame() raises RuntimeError on a read-only dataset."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False
)
with pytest.raises(RuntimeError, match="read-only"):
dataset.add_frame(_make_frame())
def test_save_episode_raises_without_writer(tmp_path, lerobot_dataset_factory):
"""save_episode() raises RuntimeError on a read-only dataset."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False
)
with pytest.raises(RuntimeError, match="read-only"):
dataset.save_episode()
def test_clear_episode_buffer_raises_without_writer(tmp_path, lerobot_dataset_factory):
"""clear_episode_buffer() raises RuntimeError on a read-only dataset."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False
)
with pytest.raises(RuntimeError, match="read-only"):
dataset.clear_episode_buffer()
# ── Reader guard ─────────────────────────────────────────────────────
def test_getitem_raises_before_finalize(tmp_path):
"""dataset[0] raises RuntimeError while recording (before finalize)."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
for _ in range(3):
dataset.add_frame(_make_frame())
dataset.save_episode()
with pytest.raises(RuntimeError, match="finalize"):
dataset[0]
def test_getitem_works_after_finalize(tmp_path):
"""After finalize(), dataset[0] returns data."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
for _ in range(3):
dataset.add_frame(_make_frame())
dataset.save_episode()
dataset.finalize()
item = dataset[0]
assert "state" in item
assert "task" in item
# ── Property delegation ──────────────────────────────────────────────
def test_fps_delegates_to_meta(tmp_path, lerobot_dataset_factory):
"""dataset.fps == dataset.meta.fps."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False
)
assert dataset.fps == dataset.meta.fps
def test_features_delegates_to_meta(tmp_path, lerobot_dataset_factory):
"""dataset.features is dataset.meta.features."""
dataset = lerobot_dataset_factory(
root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False
)
assert dataset.features is dataset.meta.features
def test_num_frames_uses_meta_in_write_mode(tmp_path):
"""In write-only mode (reader=None), num_frames comes from metadata."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
assert dataset.reader is None
assert dataset.num_frames == dataset.meta.total_frames
# ── Lifecycle ────────────────────────────────────────────────────────
def test_finalize_is_idempotent(tmp_path):
"""Calling finalize() twice does not raise."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
dataset.finalize()
dataset.finalize()
def test_has_pending_frames_lifecycle(tmp_path):
"""has_pending_frames: False -> True (add_frame) -> False (save_episode)."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds"
)
assert dataset.has_pending_frames() is False
dataset.add_frame(_make_frame())
assert dataset.has_pending_frames() is True
dataset.save_episode()
assert dataset.has_pending_frames() is False
def test_create_record_finalize_read_roundtrip(tmp_path):
"""End-to-end: create, record 2 episodes, finalize, re-open, verify data."""
root = tmp_path / "roundtrip"
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root
)
# Episode 0: 3 frames with known values
ep0_states = []
for i in range(3):
state = torch.tensor([float(i), float(i * 2)])
ep0_states.append(state)
dataset.add_frame({"task": "Task A", "state": state})
dataset.save_episode()
# Episode 1: 2 frames
ep1_states = []
for i in range(2):
state = torch.tensor([float(i + 100), float(i + 200)])
ep1_states.append(state)
dataset.add_frame({"task": "Task B", "state": state})
dataset.save_episode()
dataset.finalize()
# Re-open as read-only
reopened = LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root)
assert len(reopened) == 5
assert reopened.num_episodes == 2
# Verify episode 0
for i in range(3):
item = reopened[i]
assert torch.allclose(item["state"], ep0_states[i], atol=1e-5)
assert item["episode_index"].item() == 0
# Verify episode 1
for i in range(2):
item = reopened[3 + i]
assert torch.allclose(item["state"], ep1_states[i], atol=1e-5)
assert item["episode_index"].item() == 1

View File

@@ -534,7 +534,7 @@ class TestStreamingEncoderIntegration:
streaming_encoding=True,
)
assert dataset._streaming_encoder is not None
assert dataset.writer._streaming_encoder is not None
num_frames = 20
for _ in range(num_frames):
@@ -580,7 +580,7 @@ class TestStreamingEncoderIntegration:
streaming_encoding=False,
)
assert dataset._streaming_encoder is None
assert dataset.writer._streaming_encoder is None
num_frames = 5
for _ in range(num_frames):

143
tests/envs/test_dispatch.py Normal file
View File

@@ -0,0 +1,143 @@
"""Tests for the benchmark dispatch refactor (create_envs / get_env_processors on EnvConfig)."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
import gymnasium as gym
import pytest
from gymnasium.envs.registration import register, registry as gym_registry
from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
logger = logging.getLogger(__name__)
def test_registry_all_types():
"""make_env_config should resolve every registered EnvConfig subclass via the registry."""
known = list(EnvConfig.get_known_choices().keys())
assert len(known) >= 6
for t in known:
cfg = make_env_config(t)
assert cfg.type == t
def test_unknown_type():
with pytest.raises(ValueError, match="not registered"):
make_env_config("nonexistent")
def test_identity_processors():
"""Base class get_env_processors() returns identity pipelines."""
cfg = make_env_config("aloha")
pre, post = cfg.get_env_processors()
assert len(pre.steps) == 0 and len(post.steps) == 0
def test_delegation():
"""make_env() should call cfg.create_envs(), not use if/elif dispatch."""
sentinel = {"delegated": {0: "marker"}}
fake = type(
"Fake",
(),
{
"hub_path": None,
"create_envs": lambda self, n_envs, use_async_envs=False: sentinel,
},
)()
result = make_env(fake, n_envs=1)
assert result is sentinel
def test_processors_delegation():
"""make_env_pre_post_processors delegates to cfg.get_env_processors()."""
from lerobot.configs.policies import PreTrainedConfig
cfg = make_env_config("aloha")
pre, post = make_env_pre_post_processors(cfg, PreTrainedConfig())
assert len(pre.steps) == 0
def test_base_create_envs():
"""Base class create_envs() should build a single-task VectorEnv via gym.make()."""
gym_id = "_dispatch_test/CartPole-v99"
if gym_id not in gym_registry:
register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")
@EnvConfig.register_subclass("_dispatch_base_test")
@dataclass
class _Env(EnvConfig):
task: str = "CartPole-v99"
fps: int = 10
features: dict[str, PolicyFeature] = field(default_factory=dict)
@property
def package_name(self):
return "_dispatch_test"
@property
def gym_id(self):
return gym_id
@property
def gym_kwargs(self):
return {}
try:
envs = _Env().create_envs(n_envs=2)
assert "_dispatch_base_test" in envs
env = envs["_dispatch_base_test"][0]
assert isinstance(env, gym.vector.SyncVectorEnv)
assert env.num_envs == 2
env.close()
finally:
if gym_id in gym_registry:
del gym_registry[gym_id]
def test_custom_create_envs_override():
"""A custom EnvConfig subclass can override create_envs()."""
mock_vec = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
@EnvConfig.register_subclass("_dispatch_custom_test")
@dataclass
class _Env(EnvConfig):
task: str = "x"
features: dict[str, PolicyFeature] = field(default_factory=dict)
@property
def gym_kwargs(self):
return {}
def create_envs(self, n_envs, use_async_envs=False):
return {"custom_suite": {0: mock_vec}}
try:
result = make_env(_Env(), n_envs=1)
assert "custom_suite" in result
finally:
mock_vec.close()
def test_custom_get_env_processors_override():
"""A custom EnvConfig subclass can override get_env_processors()."""
from lerobot.processor.pipeline import PolicyProcessorPipeline
@EnvConfig.register_subclass("_dispatch_proc_test")
@dataclass
class _Env(EnvConfig):
task: str = "x"
features: dict[str, PolicyFeature] = field(default_factory=dict)
@property
def gym_kwargs(self):
return {}
def get_env_processors(self):
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
pre, post = _Env().get_env_processors()
assert isinstance(pre, PolicyProcessorPipeline)

View File

@@ -0,0 +1,624 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: E402
"""Test script for Multi-Task DiT policy.
To run tests locally:
python -m pytest tests/policies/multi_task_dit/test_multi_task_dit.py -v
"""
import os
import pytest
import torch
from torch import Tensor
pytest.importorskip("transformers")
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local transformers installation and is not meant for CI",
)
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
make_multi_task_dit_pre_post_processors,
)
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
from lerobot.utils.random_utils import seeded_context, set_seed
@pytest.fixture(autouse=True)
def set_random_seed():
seed = 17
set_seed(seed)
def create_train_batch(
batch_size: int = 2,
n_obs_steps: int = 2,
horizon: int = 16,
state_dim: int = 10,
action_dim: int = 10,
height: int = 224,
width: int = 224,
) -> dict[str, Tensor]:
"""Create a training batch with visual input and text."""
return {
"observation.state": torch.randn(batch_size, n_obs_steps, state_dim),
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, n_obs_steps, 3, height, width),
ACTION: torch.randn(batch_size, horizon, action_dim),
"task": ["pick up the cube"] * batch_size,
}
def create_observation_batch(
batch_size: int = 2, state_dim: int = 10, height: int = 224, width: int = 224
) -> dict:
"""Create observation batch for inference for a single timestep."""
return {
"observation.state": torch.randn(batch_size, state_dim),
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, height, width),
"task": ["pick up the red cube"] * batch_size,
}
def create_config(
state_dim: int = 10,
action_dim: int = 10,
n_obs_steps: int = 2,
horizon: int = 16,
n_action_steps: int = 8,
with_visual: bool = True,
height: int = 224,
width: int = 224,
) -> MultiTaskDiTConfig:
"""Create a MultiTaskDiT config for testing.
Args:
state_dim: Dimension of state observations
action_dim: Dimension of actions
n_obs_steps: Number of observation steps
horizon: Action prediction horizon
n_action_steps: Number of action steps to execute
with_visual: Whether to include visual input (default: True)
height: Image height (only used if with_visual=True)
width: Image width (only used if with_visual=True)
"""
input_features = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}
if with_visual:
input_features[f"{OBS_IMAGES}.laptop"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(3, height, width)
)
config = MultiTaskDiTConfig(
input_features=input_features,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
# Use smaller model for faster tests
hidden_dim=128,
num_layers=2,
num_heads=4,
)
config.validate_features()
return config
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_dim: int):
"""Test forward pass (training mode)."""
n_obs_steps = 2
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
policy = MultiTaskDiTPolicy(config=config)
policy.train()
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
horizon=horizon,
state_dim=state_dim,
action_dim=action_dim,
)
# Process batch through preprocessor to tokenize task text
processed_batch = preprocessor(batch)
# Test forward pass
loss, _ = policy.forward(processed_batch)
assert loss is not None
assert loss.item() is not None
assert loss.shape == ()
# Test backward pass
loss.backward()
def test_multi_task_dit_pre_post_processors():
"""Test pre and post processors for Multi-Task DiT policy."""
state_dim = 10
action_dim = 8
n_obs_steps = 2
horizon = 16
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=8,
)
config.device = "cpu"
# Set normalization mode to match the stats we're providing
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats
"ACTION": NormalizationMode.MIN_MAX,
}
# Create dataset stats for normalization
dataset_stats = {
"observation.state": {
"mean": torch.zeros(state_dim),
"std": torch.ones(state_dim),
},
"action": {
"min": torch.full((action_dim,), -1.0),
"max": torch.ones(action_dim),
},
}
# Create processors
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(
config=config, dataset_stats=dataset_stats
)
# Test preprocessor with sample data
batch = {
"observation.state": torch.randn(state_dim),
f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224),
ACTION: torch.randn(action_dim),
"task": "pick up the cube",
}
processed_batch = preprocessor(batch)
# Check that data is batched
assert processed_batch["observation.state"].shape == (1, state_dim)
assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224)
assert processed_batch[ACTION].shape == (1, action_dim)
# Check that task text was tokenized
assert OBS_LANGUAGE_TOKENS in processed_batch
assert OBS_LANGUAGE_ATTENTION_MASK in processed_batch
assert processed_batch[OBS_LANGUAGE_TOKENS].shape[0] == 1 # batch dimension
assert processed_batch[OBS_LANGUAGE_ATTENTION_MASK].shape[0] == 1 # batch dimension
# Check that data is on correct device
assert processed_batch["observation.state"].device.type == "cpu"
assert processed_batch[f"{OBS_IMAGES}.laptop"].device.type == "cpu"
assert processed_batch[ACTION].device.type == "cpu"
# Test postprocessor with sample action (PolicyAction is just a torch.Tensor)
action = torch.randn(1, action_dim)
processed_action = postprocessor(action)
# Check that action is unnormalized and on CPU
assert processed_action.shape == (1, action_dim)
assert processed_action.device.type == "cpu"
def test_multi_task_dit_pre_post_processors_normalization():
"""Test that normalization and unnormalization work correctly with simple sanity check numbers."""
state_dim = 3
action_dim = 2
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=2,
horizon=16,
n_action_steps=8,
)
config.device = "cpu"
# Set normalization mode to match the stats we're providing
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD, # Use MEAN_STD since we provide mean/std stats
"ACTION": NormalizationMode.MIN_MAX,
}
# Use simple stats that will actually transform the values
dataset_stats = {
"observation.state": {
"mean": torch.full((state_dim,), 5.0),
"std": torch.full((state_dim,), 2.0),
},
"action": {
"min": torch.zeros(action_dim),
"max": torch.full((action_dim,), 2.0),
},
}
# Create processors
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(
config=config, dataset_stats=dataset_stats
)
# Use simple input values
input_state = torch.tensor([7.0, 5.0, 3.0]) # Will normalize to [1.0, 0.0, -1.0]
input_action = torch.tensor([1.0, 2.0]) # Will normalize to [0.0, 1.0]
batch = {
"observation.state": input_state,
f"{OBS_IMAGES}.laptop": torch.rand(3, 224, 224),
ACTION: input_action,
"task": "test task",
}
# Process through preprocessor
processed_batch = preprocessor(batch)
# State normalization: (x - mean) / std
expected_normalized_state = torch.tensor([1.0, 0.0, -1.0])
assert torch.allclose(processed_batch["observation.state"][0], expected_normalized_state, atol=1e-5)
# Action normalization: (x - min) / (max - min) * 2 - 1
expected_normalized_action = torch.tensor([0.0, 1.0])
assert torch.allclose(processed_batch[ACTION][0], expected_normalized_action, atol=1e-5)
# Test unnormalization: should recover original values
normalized_action_tensor = processed_batch[ACTION][0:1] # Keep batch dimension
unnormalized_action = postprocessor(normalized_action_tensor)
# Should recover original action values
assert torch.allclose(unnormalized_action[0], input_action, atol=1e-4)
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 10, 10), (1, 6, 6)])
def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
"""Test select_action (inference mode)."""
n_obs_steps = 2
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
policy = MultiTaskDiTPolicy(config=config)
policy.eval()
policy.reset() # Reset queues before inference
# Create processors - use IDENTITY normalization when no stats provided
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Process observation through preprocessor
processed_obs = preprocessor(observation_batch)
selected_action = policy.select_action(processed_obs)
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
processed_action = postprocessor(selected_action)
assert processed_action.shape == (batch_size, action_dim)
def test_multi_task_dit_policy_diffusion_objective():
"""Test policy with diffusion objective."""
batch_size = 2
state_dim = 10
action_dim = 10
n_obs_steps = 2
horizon = 16
n_action_steps = 8
input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config = MultiTaskDiTConfig(
input_features=input_features,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
# Use diffusion objective
objective="diffusion",
noise_scheduler_type="DDPM",
num_train_timesteps=100,
num_inference_steps=10,
# Smaller model for tests
hidden_dim=128,
num_layers=2,
num_heads=4,
)
config.validate_features()
policy = MultiTaskDiTPolicy(config=config)
policy.train()
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
horizon=horizon,
state_dim=state_dim,
action_dim=action_dim,
)
# Process batch through preprocessor to tokenize task text
processed_batch = preprocessor(batch)
# Test forward pass
loss, _ = policy.forward(processed_batch)
assert loss is not None
assert loss.item() is not None
# Test inference
policy.eval()
# Use IDENTITY normalization when no stats provided
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Process observation through preprocessor
processed_obs = preprocessor(observation_batch)
selected_action = policy.select_action(processed_obs)
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
processed_action = postprocessor(selected_action)
assert processed_action.shape == (batch_size, action_dim)
def test_multi_task_dit_policy_flow_matching_objective():
"""Test policy with flow matching objective."""
batch_size = 2
state_dim = 10
action_dim = 10
n_obs_steps = 2
horizon = 16
n_action_steps = 8
input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config = MultiTaskDiTConfig(
input_features=input_features,
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
# Use flow matching objective
objective="flow_matching",
sigma_min=0.0,
num_integration_steps=10, # Fewer steps for faster tests
integration_method="euler",
# Smaller model for tests
hidden_dim=128,
num_layers=2,
num_heads=4,
)
config.validate_features()
policy = MultiTaskDiTPolicy(config=config)
policy.train()
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
horizon=horizon,
state_dim=state_dim,
action_dim=action_dim,
)
# Process batch through preprocessor to tokenize task text
processed_batch = preprocessor(batch)
# Test forward pass
loss, _ = policy.forward(processed_batch)
assert loss is not None
assert loss.item() is not None
# Test inference
policy.eval()
# Use IDENTITY normalization when no stats provided
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Process observation through preprocessor
processed_obs = preprocessor(observation_batch)
selected_action = policy.select_action(processed_obs)
# Process action through postprocessor (PolicyAction is just a torch.Tensor)
processed_action = postprocessor(selected_action)
assert processed_action.shape == (batch_size, action_dim)
def test_multi_task_dit_policy_save_and_load(tmp_path):
"""Test that the policy can be saved and loaded correctly."""
root = tmp_path / "test_multi_task_dit_save_and_load"
state_dim = 10
action_dim = 10
batch_size = 2
n_obs_steps = 2
horizon = 16
n_action_steps = 8
config = create_config(
state_dim=state_dim,
action_dim=action_dim,
n_obs_steps=n_obs_steps,
horizon=horizon,
n_action_steps=n_action_steps,
)
policy = MultiTaskDiTPolicy(config=config)
policy.eval()
# Get device before saving
device = next(policy.parameters()).device
policy.save_pretrained(root)
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
# Explicitly move loaded_policy to the same device
loaded_policy.to(device)
loaded_policy.eval()
batch = create_train_batch(
batch_size=batch_size,
n_obs_steps=n_obs_steps,
horizon=horizon,
state_dim=state_dim,
action_dim=action_dim,
)
# Use preprocessor to handle tokenization
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
with torch.no_grad():
with seeded_context(12):
# Process batch through preprocessor
processed_batch = preprocessor(batch)
# Move batch to the same device as the policy
for key in processed_batch:
if isinstance(processed_batch[key], torch.Tensor):
processed_batch[key] = processed_batch[key].to(device)
# Collect policy values before saving
loss, _ = policy.forward(processed_batch)
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
# Process observation through preprocessor
processed_obs = preprocessor(observation_batch)
actions = policy.select_action(processed_obs)
with seeded_context(12):
# Process batch through preprocessor
processed_batch = preprocessor(batch)
# Collect policy values after loading
loaded_loss, _ = loaded_policy.forward(processed_batch)
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
processed_obs = preprocessor(loaded_observation_batch)
loaded_actions = loaded_policy.select_action(processed_obs)
# Compare state dicts
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
for k in policy.state_dict():
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
# Compare values before and after saving and loading
assert torch.allclose(loss, loaded_loss)
assert torch.allclose(actions, loaded_actions)
def test_multi_task_dit_policy_get_optim_params():
"""Test that the policy returns correct optimizer parameter groups."""
config = create_config(
state_dim=10,
action_dim=10,
n_obs_steps=2,
horizon=16,
n_action_steps=8,
)
policy = MultiTaskDiTPolicy(config=config)
param_groups = policy.get_optim_params()
# Should have 2 parameter groups: non-vision and vision encoder
assert len(param_groups) == 2
# First group is non-vision params (no lr specified, will use default)
assert "params" in param_groups[0]
assert len(param_groups[0]["params"]) > 0
# Second group is vision encoder params with different lr
assert "params" in param_groups[1]
assert "lr" in param_groups[1]
expected_lr = config.optimizer_lr * config.vision_encoder_lr_multiplier
assert param_groups[1]["lr"] == expected_lr

View File

@@ -0,0 +1,559 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ActionInterpolator and its interaction with ActionQueue (RTC)."""
import pytest
import torch
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
# ====================== Fixtures ======================
@pytest.fixture
def interp2():
"""Create an ActionInterpolator with multiplier=2."""
return ActionInterpolator(multiplier=2)
@pytest.fixture
def interp3():
"""Create an ActionInterpolator with multiplier=3."""
return ActionInterpolator(multiplier=3)
# ====================== Initialization Tests ======================
def test_interpolator_multiplier_1_no_interpolation():
"""Test multiplier=1 creates a disabled interpolator."""
interp = ActionInterpolator(multiplier=1)
assert interp.multiplier == 1
assert not interp.enabled
def test_interpolator_multiplier_2_enabled():
"""Test multiplier=2 creates an enabled interpolator."""
interp = ActionInterpolator(multiplier=2)
assert interp.multiplier == 2
assert interp.enabled
def test_interpolator_multiplier_0_raises():
"""Test multiplier=0 raises ValueError."""
with pytest.raises(ValueError, match="multiplier must be >= 1"):
ActionInterpolator(multiplier=0)
def test_interpolator_negative_multiplier_raises():
"""Test negative multiplier raises ValueError."""
with pytest.raises(ValueError, match="multiplier must be >= 1"):
ActionInterpolator(multiplier=-1)
def test_interpolator_default_multiplier_is_1():
"""Test default multiplier is 1 (disabled)."""
interp = ActionInterpolator()
assert interp.multiplier == 1
assert not interp.enabled
# ====================== needs_new_action Tests ======================
def test_needs_new_action_true_initially(interp2):
"""Test needs_new_action() returns True before any action is added."""
assert interp2.needs_new_action()
def test_needs_new_action_false_after_add(interp2):
"""Test needs_new_action() returns False right after add()."""
interp2.add(torch.tensor([1.0, 2.0]))
assert not interp2.needs_new_action()
def test_needs_new_action_true_after_buffer_exhausted(interp2):
"""Test needs_new_action() returns True after consuming all buffered actions."""
interp2.add(torch.tensor([1.0, 2.0]))
interp2.get()
assert interp2.needs_new_action()
def test_needs_new_action_true_after_all_interpolated_consumed(interp2):
"""Test needs_new_action() tracks interpolated sub-steps correctly."""
interp2.add(torch.tensor([0.0, 0.0]))
interp2.get()
assert interp2.needs_new_action()
interp2.add(torch.tensor([2.0, 4.0]))
interp2.get()
assert not interp2.needs_new_action()
interp2.get()
assert interp2.needs_new_action()
# ====================== Passthrough Tests (multiplier=1) ======================
def test_passthrough_single_action_returned_as_is():
"""Test multiplier=1 returns the action unchanged."""
interp = ActionInterpolator(multiplier=1)
action = torch.tensor([3.0, 5.0])
interp.add(action)
result = interp.get()
assert result is not None
torch.testing.assert_close(result, action)
def test_passthrough_none_after_single_get():
"""Test multiplier=1 returns None after consuming the single action."""
interp = ActionInterpolator(multiplier=1)
interp.add(torch.tensor([1.0]))
interp.get()
assert interp.get() is None
def test_passthrough_sequential_actions():
"""Test multiplier=1 passes through consecutive actions one at a time."""
interp = ActionInterpolator(multiplier=1)
for val in [1.0, 2.0, 3.0]:
action = torch.tensor([val])
interp.add(action)
result = interp.get()
torch.testing.assert_close(result, action)
assert interp.get() is None
# ====================== Interpolation Tests (multiplier=2) ======================
def test_interpolation_2x_first_action_no_interpolation(interp2):
"""Test first action has no previous, so buffer is just [action]."""
interp2.add(torch.tensor([0.0, 0.0]))
result = interp2.get()
torch.testing.assert_close(result, torch.tensor([0.0, 0.0]))
assert interp2.get() is None
def test_interpolation_2x_second_action_produces_two_steps(interp2):
"""Test second action produces 2 interpolated sub-steps."""
interp2.add(torch.tensor([0.0, 0.0]))
interp2.get()
interp2.add(torch.tensor([2.0, 4.0]))
step1 = interp2.get()
step2 = interp2.get()
torch.testing.assert_close(step1, torch.tensor([1.0, 2.0]))
torch.testing.assert_close(step2, torch.tensor([2.0, 4.0]))
assert interp2.get() is None
def test_interpolation_2x_three_consecutive_actions(interp2):
"""Test interpolation across three consecutive actions."""
a0 = torch.tensor([0.0])
a1 = torch.tensor([4.0])
a2 = torch.tensor([10.0])
interp2.add(a0)
torch.testing.assert_close(interp2.get(), a0)
interp2.add(a1)
torch.testing.assert_close(interp2.get(), torch.tensor([2.0]))
torch.testing.assert_close(interp2.get(), torch.tensor([4.0]))
interp2.add(a2)
torch.testing.assert_close(interp2.get(), torch.tensor([7.0]))
torch.testing.assert_close(interp2.get(), torch.tensor([10.0]))
# ====================== Interpolation Tests (multiplier=3) ======================
def test_interpolation_3x_produces_three_steps(interp3):
"""Test multiplier=3 produces 3 interpolated sub-steps."""
interp3.add(torch.tensor([0.0, 0.0]))
interp3.get()
interp3.add(torch.tensor([3.0, 6.0]))
s1 = interp3.get()
s2 = interp3.get()
s3 = interp3.get()
torch.testing.assert_close(s1, torch.tensor([1.0, 2.0]))
torch.testing.assert_close(s2, torch.tensor([2.0, 4.0]))
torch.testing.assert_close(s3, torch.tensor([3.0, 6.0]))
assert interp3.get() is None
def test_interpolation_3x_last_step_equals_target(interp3):
"""Test last interpolated step equals the target action exactly."""
interp3.add(torch.tensor([10.0]))
interp3.get()
target = torch.tensor([100.0])
interp3.add(target)
interp3.get()
interp3.get()
last = interp3.get()
torch.testing.assert_close(last, target)
# ====================== Reset Tests ======================
def test_reset_clears_buffer(interp2):
"""Test reset() clears the action buffer."""
interp2.add(torch.tensor([1.0]))
interp2.reset()
assert interp2.needs_new_action()
assert interp2.get() is None
def test_reset_clears_prev(interp2):
"""Test after reset, next add produces single-element buffer (no prev)."""
interp2.add(torch.tensor([0.0]))
interp2.get()
interp2.add(torch.tensor([10.0]))
interp2.get()
interp2.get()
interp2.reset()
interp2.add(torch.tensor([5.0]))
result = interp2.get()
torch.testing.assert_close(result, torch.tensor([5.0]))
assert interp2.get() is None
def test_reset_episode_boundary(interp2):
"""Test reset between two simulated episodes."""
interp2.add(torch.tensor([0.0]))
interp2.get()
interp2.add(torch.tensor([10.0]))
interp2.get()
interp2.get()
interp2.reset()
interp2.add(torch.tensor([100.0]))
result = interp2.get()
torch.testing.assert_close(result, torch.tensor([100.0]))
assert interp2.get() is None
# ====================== get_control_interval Tests ======================
def test_control_interval_30fps_multiplier_1():
"""Test control interval at 30fps with no interpolation."""
interp = ActionInterpolator(multiplier=1)
assert interp.get_control_interval(30.0) == pytest.approx(1.0 / 30.0)
def test_control_interval_30fps_multiplier_2(interp2):
"""Test control interval at 30fps with 2x interpolation."""
assert interp2.get_control_interval(30.0) == pytest.approx(1.0 / 60.0)
def test_control_interval_30fps_multiplier_3(interp3):
"""Test control interval at 30fps with 3x interpolation."""
assert interp3.get_control_interval(30.0) == pytest.approx(1.0 / 90.0)
def test_control_interval_60fps_multiplier_2(interp2):
"""Test control interval at 60fps with 2x interpolation."""
assert interp2.get_control_interval(60.0) == pytest.approx(1.0 / 120.0)
# ====================== get() on Empty Tests ======================
def test_get_returns_none_before_any_add():
"""Test get() returns None when no action has been added."""
interp = ActionInterpolator(multiplier=2)
assert interp.get() is None
def test_get_returns_none_after_reset(interp2):
"""Test get() returns None after reset."""
interp2.add(torch.tensor([1.0]))
interp2.reset()
assert interp2.get() is None
# ====================== Multi-Dimensional Action Tests ======================
def test_6dof_interpolation(interp2):
"""Test interpolation works correctly with 6-dimensional actions."""
prev = torch.zeros(6)
target = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
interp2.add(prev)
interp2.get()
interp2.add(target)
mid = interp2.get()
end = interp2.get()
torch.testing.assert_close(mid, target / 2)
torch.testing.assert_close(end, target)
# ====================== Simulated Control Loop Tests ======================
def test_control_loop_produces_correct_action_count():
"""Test N policy actions with multiplier M yields 1 + (N-1)*M robot commands."""
multiplier = 3
n_policy_actions = 5
interp = ActionInterpolator(multiplier=multiplier)
robot_commands = 0
for i in range(n_policy_actions):
action = torch.tensor([float(i)])
if interp.needs_new_action():
interp.add(action)
while True:
a = interp.get()
if a is None:
break
robot_commands += 1
expected = 1 + (n_policy_actions - 1) * multiplier
assert robot_commands == expected
def test_control_loop_monotonic_increase():
"""Test actions [0, 1, 2, 3] with multiplier=2 produce monotonically increasing values."""
interp = ActionInterpolator(multiplier=2)
all_values = []
for i in range(4):
interp.add(torch.tensor([float(i)]))
while True:
a = interp.get()
if a is None:
break
all_values.append(a.item())
for i in range(1, len(all_values)):
assert all_values[i] >= all_values[i - 1]
# ====================== ActionQueue + ActionInterpolator Integration Tests ======================
def _make_chunk(n_steps: int, action_dim: int = 2, offset: float = 0.0) -> torch.Tensor:
"""Create a simple action chunk: each row is [offset + step_idx, offset + step_idx]."""
return torch.arange(n_steps, dtype=torch.float32).unsqueeze(1).expand(-1, action_dim) + offset
def test_queue_interpolator_consumption_rate_matches_base_fps():
"""Test queue.get() is called at base fps rate, not multiplied fps."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=3)
chunk = _make_chunk(10)
queue.merge(chunk, chunk.clone(), real_delay=0)
queue_gets = 0
control_ticks = 0
while True:
if interp.needs_new_action():
if queue.empty():
break
action = queue.get()
if action is None:
break
interp.add(action)
queue_gets += 1
result = interp.get()
if result is not None:
control_ticks += 1
assert queue_gets == 10
assert control_ticks == 1 + 9 * 3
def test_queue_interpolator_leftover_decreases_only_on_queue_get():
"""Test get_left_over() shrinks only on queue.get(), not on interpolator sub-steps."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=3)
chunk = _make_chunk(6)
queue.merge(chunk, chunk.clone(), real_delay=0)
assert interp.needs_new_action()
interp.add(queue.get())
leftover_after_first_get = queue.get_left_over()
assert leftover_after_first_get is not None
assert len(leftover_after_first_get) == 5
interp.get()
assert len(queue.get_left_over()) == 5
interp.add(queue.get())
assert len(queue.get_left_over()) == 4
for _ in range(3):
assert interp.get() is not None
assert len(queue.get_left_over()) == 4
def test_queue_interpolator_processed_leftover_tracks_queue_index():
"""Test get_processed_left_over() reflects queue's last_index, not interpolator state."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=2)
original = _make_chunk(8, offset=0.0)
processed = _make_chunk(8, offset=100.0)
queue.merge(original, processed, real_delay=0)
left = queue.get_processed_left_over()
assert len(left) == 8
for _ in range(3):
if interp.needs_new_action():
action = queue.get()
if action is not None:
interp.add(action)
interp.get()
proc_left = queue.get_processed_left_over()
orig_left = queue.get_left_over()
assert proc_left is not None and orig_left is not None
assert len(proc_left) == len(orig_left)
assert proc_left[0, 0].item() >= 100.0
assert orig_left[0, 0].item() < 100.0
def test_queue_interpolator_merge_resets_queue_but_interpolator_keeps_prev():
"""Test queue merge doesn't affect interpolator's prev, enabling smooth transitions."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=2)
chunk1 = torch.tensor([[0.0], [2.0], [4.0], [6.0], [8.0]])
queue.merge(chunk1, chunk1.clone(), real_delay=0)
consumed = []
for _ in range(5):
if interp.needs_new_action():
a = queue.get()
if a is not None:
interp.add(a)
r = interp.get()
if r is not None:
consumed.append(r.item())
assert interp.needs_new_action()
assert consumed[-1] == pytest.approx(4.0)
idx_before = queue.get_action_index()
chunk2 = torch.tensor([[10.0], [12.0], [14.0]])
queue.merge(chunk2, chunk2.clone(), real_delay=0, action_index_before_inference=idx_before)
first_action = queue.get()
assert first_action is not None
interp.add(first_action)
first_from_new = interp.get()
assert first_from_new is not None
assert first_from_new.item() == pytest.approx(7.0)
def test_queue_interpolator_reset_does_not_affect_queue():
"""Test interpolator reset leaves queue state untouched."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=2)
chunk = _make_chunk(5)
queue.merge(chunk, chunk.clone(), real_delay=0)
interp.add(queue.get())
interp.get()
interp.add(queue.get())
interp.get()
interp.get()
assert queue.qsize() == 3
interp.reset()
assert queue.qsize() == 3
assert len(queue.get_left_over()) == 3
interp.add(queue.get())
result = interp.get()
assert result is not None
assert queue.qsize() == 2
def test_queue_interpolator_no_interpolation_1_to_1():
"""Test multiplier=1 produces exactly 1 robot command per queue.get()."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=1)
chunk = _make_chunk(5)
queue.merge(chunk, chunk.clone(), real_delay=0)
robot_commands = 0
while not queue.empty():
if interp.needs_new_action():
action = queue.get()
if action is not None:
interp.add(action)
result = interp.get()
if result is not None:
robot_commands += 1
assert robot_commands == 5
def test_queue_interpolator_delay_skips_stale_actions():
"""Test merge with delay correctly skips stale actions for the interpolator."""
cfg = RTCConfig(enabled=True, execution_horizon=10)
queue = ActionQueue(cfg)
interp = ActionInterpolator(multiplier=2)
chunk1 = _make_chunk(10)
queue.merge(chunk1, chunk1.clone(), real_delay=0)
for _ in range(5):
if interp.needs_new_action():
a = queue.get()
if a is not None:
interp.add(a)
interp.get()
assert queue.get_action_index() == 3
chunk2 = _make_chunk(10, offset=100.0)
queue.merge(chunk2, chunk2.clone(), real_delay=3, action_index_before_inference=0)
first_action = queue.get()
assert first_action is not None
torch.testing.assert_close(first_action, torch.tensor([103.0, 103.0]))

View File

@@ -25,7 +25,7 @@ import torch
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
# ====================== Fixtures ======================
# Fixtures
@pytest.fixture
@@ -63,7 +63,7 @@ def action_queue_rtc_disabled(rtc_config_disabled):
return ActionQueue(rtc_config_disabled)
# ====================== Initialization Tests ======================
# Initialization tests
def test_action_queue_initialization_rtc_enabled(rtc_config_enabled):
@@ -84,7 +84,7 @@ def test_action_queue_initialization_rtc_disabled(rtc_config_disabled):
assert queue.cfg.enabled is False
# ====================== get() Tests ======================
# get() tests
def test_get_returns_none_when_empty(action_queue_rtc_enabled):
@@ -136,7 +136,7 @@ def test_get_increments_last_index(action_queue_rtc_enabled, sample_actions):
assert action_queue_rtc_enabled.last_index == 2
# ====================== qsize() Tests ======================
# qsize() tests
def test_qsize_returns_zero_when_empty(action_queue_rtc_enabled):
@@ -167,7 +167,7 @@ def test_qsize_after_exhaustion(action_queue_rtc_enabled, sample_actions):
assert action_queue_rtc_enabled.qsize() == 0
# ====================== empty() Tests ======================
# empty() tests
def test_empty_returns_true_when_empty(action_queue_rtc_enabled):
@@ -202,7 +202,7 @@ def test_empty_after_full_consumption(action_queue_rtc_enabled, sample_actions):
assert action_queue_rtc_enabled.empty() is True
# ====================== get_action_index() Tests ======================
# get_action_index() tests
def test_get_action_index_initial_value(action_queue_rtc_enabled):
@@ -222,7 +222,7 @@ def test_get_action_index_after_consumption(action_queue_rtc_enabled, sample_act
assert action_queue_rtc_enabled.get_action_index() == 3
# ====================== get_left_over() Tests ======================
# get_left_over() tests
def test_get_left_over_returns_none_when_empty(action_queue_rtc_enabled):
@@ -269,7 +269,7 @@ def test_get_left_over_returns_empty_after_exhaustion(action_queue_rtc_enabled,
assert leftover.shape == (0, 6)
# ====================== merge() with RTC Enabled Tests ======================
# merge() with RTC enabled tests
def test_merge_replaces_queue_when_rtc_enabled(action_queue_rtc_enabled, sample_actions):
@@ -336,7 +336,7 @@ def test_merge_with_large_delay(action_queue_rtc_enabled, sample_actions):
assert action_queue_rtc_enabled.qsize() == 0
# ====================== merge() with RTC Disabled Tests ======================
# merge() with RTC disabled tests
def test_merge_appends_when_rtc_disabled(action_queue_rtc_disabled, sample_actions):
@@ -402,7 +402,7 @@ def test_merge_first_call_with_rtc_disabled(action_queue_rtc_disabled, sample_ac
assert action_queue_rtc_disabled.last_index == 0
# ====================== merge() with Different Action Shapes Tests ======================
# merge() with different action shapes tests
def test_merge_with_different_action_dims():
@@ -431,7 +431,7 @@ def test_merge_with_different_lengths():
assert queue.qsize() == 35
# ====================== merge() Delay Validation Tests ======================
# merge() delay validation tests
def test_merge_validates_delay_consistency(action_queue_rtc_enabled, sample_actions, caplog):
@@ -509,7 +509,7 @@ def test_merge_skips_validation_when_action_index_none(action_queue_rtc_enabled,
assert "Indexes diff is not equal to real delay" not in caplog.text
# ====================== Thread Safety Tests ======================
# Thread safety tests
def test_get_is_thread_safe(action_queue_rtc_enabled, sample_actions):
@@ -621,7 +621,7 @@ def test_concurrent_get_and_merge(action_queue_rtc_disabled, sample_actions):
assert consumed_count[0] <= 200
# ====================== get_left_over() Thread Safety Tests ======================
# get_left_over() thread safety tests
def test_get_left_over_is_thread_safe(action_queue_rtc_enabled, sample_actions):
@@ -670,7 +670,7 @@ def test_get_left_over_is_thread_safe(action_queue_rtc_enabled, sample_actions):
assert len(leftovers) > 0
# ====================== Edge Cases Tests ======================
# Edge cases tests
def test_queue_with_single_action(action_queue_rtc_enabled):
@@ -773,7 +773,7 @@ def test_qsize_with_none_queue(action_queue_rtc_enabled):
assert action_queue_rtc_enabled.qsize() == 0
# ====================== Integration Tests ======================
# Integration tests
def test_typical_rtc_workflow(action_queue_rtc_enabled, sample_actions):

View File

@@ -0,0 +1,607 @@
"""Tests for RTC + relative actions integration.
Validates that Real-Time Chunking (RTC) works correctly when the policy uses
relative actions. The key invariant: RTC guidance operates in model space
(normalized relative actions), while the robot receives absolute actions after postprocessing.
Flow under test:
Preprocessor: raw obs → relative step caches state → normalizer
Model: generates normalized relative actions (guided by RTC using leftover relative actions)
Postprocessor: unnormalize → absolute step (relative + cached state) → robot actions
"""
import importlib.util
import sys
from pathlib import Path
import torch
from lerobot.configs.types import (
FeatureType,
NormalizationMode,
PolicyFeature,
RTCAttentionSchedule,
)
from lerobot.processor import TransitionKey, batch_to_transition
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from lerobot.processor.relative_action_processor import (
AbsoluteActionsProcessorStep,
RelativeActionsProcessorStep,
to_relative_actions,
)
from lerobot.utils.constants import ACTION, OBS_STATE
def _import_rtc_module(module_name: str, filename: str):
"""Import an RTC module directly from its file path, bypassing lerobot.policies.__init__."""
rtc_dir = Path(__file__).resolve().parents[3] / "src" / "lerobot" / "policies" / "rtc"
spec = importlib.util.spec_from_file_location(module_name, rtc_dir / filename)
mod = importlib.util.module_from_spec(spec)
sys.modules[module_name] = mod
spec.loader.exec_module(mod)
return mod
_rtc_cfg_mod = _import_rtc_module("lerobot.policies.rtc.configuration_rtc", "configuration_rtc.py")
RTCConfig = _rtc_cfg_mod.RTCConfig
_action_queue_mod = _import_rtc_module("lerobot.policies.rtc.action_queue", "action_queue.py")
ActionQueue = _action_queue_mod.ActionQueue
_rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug_tracker.py")
_rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py")
RTCProcessor = _rtc_mod.RTCProcessor
ACTION_DIM = 6
CHUNK_SIZE = 50
EXECUTION_HORIZON = 10
def _make_rtc_config(enabled=True):
return RTCConfig(
enabled=enabled,
execution_horizon=EXECUTION_HORIZON,
max_guidance_weight=10.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
def _make_relative_pipeline(action_dim=ACTION_DIM, norm_mode=NormalizationMode.MEAN_STD):
"""Build paired relative/absolute processor steps and normalizer/unnormalizer."""
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
norm_map = {FeatureType.ACTION: norm_mode}
stats = {
ACTION: {
"mean": torch.zeros(action_dim).numpy(),
"std": torch.ones(action_dim).numpy(),
"q01": (-2 * torch.ones(action_dim)).numpy(),
"q99": (2 * torch.ones(action_dim)).numpy(),
"min": (-3 * torch.ones(action_dim)).numpy(),
"max": (3 * torch.ones(action_dim)).numpy(),
}
}
relative_step = RelativeActionsProcessorStep(enabled=True)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
return relative_step, normalizer, unnormalizer, absolute_step
class TestActionQueueRelativeActions:
"""Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot."""
def test_left_over_returns_relative_actions(self):
"""get_left_over() should return the original (relative-space) actions."""
cfg = _make_rtc_config()
queue = ActionQueue(cfg)
relative_actions = torch.randn(CHUNK_SIZE, ACTION_DIM)
absolute_actions = torch.randn(CHUNK_SIZE, ACTION_DIM)
queue.merge(relative_actions, absolute_actions, real_delay=0)
for _ in range(5):
queue.get()
leftover = queue.get_left_over()
torch.testing.assert_close(leftover, relative_actions[5:])
def test_robot_receives_absolute_actions(self):
"""The robot (via get()) should receive postprocessed absolute actions."""
cfg = _make_rtc_config()
queue = ActionQueue(cfg)
relative_actions = torch.randn(CHUNK_SIZE, ACTION_DIM)
absolute_actions = torch.randn(CHUNK_SIZE, ACTION_DIM)
queue.merge(relative_actions, absolute_actions, real_delay=0)
first_action = queue.get()
torch.testing.assert_close(first_action, absolute_actions[0])
class TestRTCDenoiseWithRelativeLeftovers:
"""Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over."""
def test_first_chunk_no_guidance(self):
"""First chunk (no leftovers) should return v_t without guidance."""
rtc = RTCProcessor(_make_rtc_config())
x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
def mock_denoise(x):
return torch.ones_like(x)
result = rtc.denoise_step(
x_t=x_t,
prev_chunk_left_over=None,
inference_delay=0,
time=0.5,
original_denoise_step_partial=mock_denoise,
)
torch.testing.assert_close(result, torch.ones_like(x_t))
def test_relative_leftovers_shape_preserved(self):
"""RTC output should have the same shape as input regardless of leftover shape."""
rtc = RTCProcessor(_make_rtc_config())
x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
shorter_leftover = torch.randn(1, 20, ACTION_DIM)
def mock_denoise(x):
return torch.zeros_like(x)
result = rtc.denoise_step(
x_t=x_t,
prev_chunk_left_over=shorter_leftover,
inference_delay=5,
time=0.5,
original_denoise_step_partial=mock_denoise,
)
assert result.shape == x_t.shape
def test_guidance_steers_toward_previous_relative_actions(self):
"""RTC guidance should push x1_t toward prev_chunk_left_over in relative space."""
rtc = RTCProcessor(_make_rtc_config())
x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
prev_relatives = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
def mock_denoise(x):
return torch.zeros_like(x)
result_without_guidance = rtc.denoise_step(
x_t=x_t.clone(),
prev_chunk_left_over=None,
inference_delay=5,
time=0.5,
original_denoise_step_partial=mock_denoise,
)
result_with_guidance = rtc.denoise_step(
x_t=x_t.clone(),
prev_chunk_left_over=prev_relatives,
inference_delay=5,
time=0.5,
original_denoise_step_partial=mock_denoise,
)
assert not torch.allclose(result_with_guidance, result_without_guidance, atol=1e-6)
class TestFullPipelineRelativeRTC:
"""End-to-end test of the RTC + relative actions pipeline matching eval_with_real_robot.py flow."""
def test_preprocessor_caches_state_for_postprocessor(self):
"""Preprocessor's relative step should cache state so postprocessor can convert back."""
relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()
state = torch.randn(1, ACTION_DIM)
actions = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
batch = {ACTION: actions, OBS_STATE: state}
transition = batch_to_transition(batch)
relative_step(transition)
assert relative_step._last_state is not None
torch.testing.assert_close(relative_step._last_state, state)
def test_preprocessor_caches_state_without_actions(self):
"""During inference, preprocessor receives only observations (no actions).
Relative step should still cache state for the postprocessor."""
relative_step, _, _, _ = _make_relative_pipeline()
state = torch.randn(1, ACTION_DIM)
batch = {OBS_STATE: state}
transition = batch_to_transition(batch)
relative_step(transition)
assert relative_step._last_state is not None
torch.testing.assert_close(relative_step._last_state, state)
def test_roundtrip_with_identity_normalization(self):
"""Actions → relative → normalize → [model] → unnormalize → absolute should recover originals.
Using mean=0, std=1 normalization (identity)."""
relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()
state = torch.randn(1, ACTION_DIM)
actions = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
batch = {ACTION: actions.clone(), OBS_STATE: state}
transition = batch_to_transition(batch)
t1 = relative_step(transition)
t2 = normalizer(t1)
model_output = t2[TransitionKey.ACTION].clone()
model_transition = {TransitionKey.ACTION: model_output}
t3 = unnormalizer(model_transition)
t4 = absolute_step(t3)
recovered = t4[TransitionKey.ACTION]
torch.testing.assert_close(recovered, actions, atol=1e-5, rtol=1e-5)
def test_eval_loop_simulation(self):
"""Simulate the eval_with_real_robot.py loop with relative actions.
Iteration 1: No leftovers → model generates relative actions → store for RTC
Iteration 2: Use leftovers as RTC guidance → model generates new relative actions
Both iterations: postprocessor converts relative actions to absolute for robot
"""
relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()
rtc = RTCProcessor(_make_rtc_config())
queue = ActionQueue(_make_rtc_config())
def mock_model(prev_chunk_left_over, inference_delay, state):
"""Simulate model generating relative actions with RTC."""
x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
def denoise(x):
return -0.1 * x
guided_v = rtc.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=0.5,
original_denoise_step_partial=denoise,
)
return x_t - 0.5 * guided_v
# --- Iteration 1: first chunk, no leftovers ---
state_1 = torch.randn(1, ACTION_DIM)
obs_batch_1 = {OBS_STATE: state_1}
relative_step(batch_to_transition(obs_batch_1))
model_relatives_1 = mock_model(prev_chunk_left_over=None, inference_delay=0, state=state_1)
original_actions_1 = model_relatives_1.squeeze(0)
model_transition_1 = {TransitionKey.ACTION: model_relatives_1}
postprocessed_1 = absolute_step(unnormalizer(model_transition_1))[TransitionKey.ACTION].squeeze(0)
queue.merge(original_actions_1, postprocessed_1, real_delay=0)
# Consume some actions (simulate robot executing)
for _ in range(5):
action = queue.get()
assert action is not None
# --- Iteration 2: use leftovers for RTC ---
prev_actions = queue.get_left_over()
assert prev_actions is not None
assert prev_actions.shape[0] == CHUNK_SIZE - 5
state_2 = state_1 + 0.01 * torch.randn(1, ACTION_DIM)
obs_batch_2 = {OBS_STATE: state_2}
relative_step(batch_to_transition(obs_batch_2))
model_relatives_2 = mock_model(
prev_chunk_left_over=prev_actions.unsqueeze(0), inference_delay=3, state=state_2
)
original_actions_2 = model_relatives_2.squeeze(0)
model_transition_2 = {TransitionKey.ACTION: model_relatives_2}
postprocessed_2 = absolute_step(unnormalizer(model_transition_2))[TransitionKey.ACTION].squeeze(0)
queue.merge(original_actions_2, postprocessed_2, real_delay=3)
# Postprocessed actions should be in absolute space
action = queue.get()
assert action is not None
assert action.shape == (ACTION_DIM,)
# Verify leftovers are in relative space (original_queue stores relative actions)
leftover_relatives = queue.get_left_over()
assert leftover_relatives is not None
assert leftover_relatives.shape[1] == ACTION_DIM
def test_postprocessor_uses_correct_state_per_iteration(self):
"""Each iteration's postprocessor should use the state from that iteration's preprocessor,
not a stale state from a previous iteration."""
relative_step, _, unnormalizer, absolute_step = _make_relative_pipeline()
state_1 = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]])
state_2 = torch.tensor([[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]])
relatives = torch.zeros(1, 5, ACTION_DIM)
# Iteration 1: cache state_1
relative_step(batch_to_transition({OBS_STATE: state_1}))
result_1 = absolute_step(unnormalizer({TransitionKey.ACTION: relatives.clone()}))[
TransitionKey.ACTION
]
# relative=0 + state_1 should give state_1
for t in range(5):
torch.testing.assert_close(result_1[0, t], state_1[0], atol=1e-5, rtol=1e-5)
# Iteration 2: cache state_2
relative_step(batch_to_transition({OBS_STATE: state_2}))
result_2 = absolute_step(unnormalizer({TransitionKey.ACTION: relatives.clone()}))[
TransitionKey.ACTION
]
for t in range(5):
torch.testing.assert_close(result_2[0, t], state_2[0], atol=1e-5, rtol=1e-5)
class TestStateRebasingApproximation:
"""Verify that the approximation from not rebasing leftover relative actions is small
when state changes between inference calls are small (real-time control regime)."""
def test_small_state_change_produces_small_error(self):
"""With small state changes (typical in real-time control),
using stale relative actions for RTC guidance introduces negligible error."""
state_old = torch.randn(1, ACTION_DIM)
state_new = state_old + 0.01 * torch.randn(1, ACTION_DIM)
actions_absolute = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
mask = [True] * ACTION_DIM
relatives_old = to_relative_actions(actions_absolute, state_old, mask)
relatives_new = to_relative_actions(actions_absolute, state_new, mask)
error = (relatives_old - relatives_new).abs().mean()
state_change = (state_old - state_new).abs().mean()
# Error should be proportional to state change
assert error < 0.1, (
f"Relative-action error {error:.4f} should be small for small state change {state_change:.4f}"
)
def test_large_state_change_produces_proportional_error(self):
"""With large state changes, stale relative actions diverge more (but RTC guidance decays)."""
state_old = torch.randn(1, ACTION_DIM)
state_new = state_old + 10.0 * torch.randn(1, ACTION_DIM)
actions_absolute = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
mask = [True] * ACTION_DIM
relatives_old = to_relative_actions(actions_absolute, state_old, mask)
relatives_new = to_relative_actions(actions_absolute, state_new, mask)
error = (relatives_old - relatives_new).abs().mean()
state_change = (state_old - state_new).abs().mean()
# Error should be roughly equal to state change
torch.testing.assert_close(
error.clone().detach(), state_change.clone().detach(), atol=1e-5, rtol=1e-5
)
def test_excluded_joints_not_affected_by_state_change(self):
"""Joints excluded from relative conversion should not contribute rebasing error."""
state_old = torch.randn(1, ACTION_DIM)
state_new = state_old.clone()
state_new[0, -1] = state_old[0, -1] + 100.0
actions = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
mask = [True] * (ACTION_DIM - 1) + [False]
relatives_old = to_relative_actions(actions, state_old, mask)
relatives_new = to_relative_actions(actions, state_new, mask)
# Last dim (excluded) should have zero error
error_excluded = (relatives_old[..., -1] - relatives_new[..., -1]).abs().max()
assert error_excluded < 1e-6, f"Excluded joint should have zero error, got {error_excluded}"
def _detect_relative_actions(preprocessor) -> bool:
"""Mirror of the helper in eval_with_real_robot.py for testing without importing it."""
return any(isinstance(step, RelativeActionsProcessorStep) and step.enabled for step in preprocessor.steps)
class TestDetectRelativeActions:
"""Test the _detect_relative_actions helper logic used by eval_with_real_robot.py."""
def test_detects_enabled_relative_step(self):
class FakePipeline:
steps = [RelativeActionsProcessorStep(enabled=True)]
assert _detect_relative_actions(FakePipeline()) is True
def test_ignores_disabled_relative_step(self):
class FakePipeline:
steps = [RelativeActionsProcessorStep(enabled=False)]
assert _detect_relative_actions(FakePipeline()) is False
def test_returns_false_when_no_relative_step(self):
class FakePipeline:
steps = []
assert _detect_relative_actions(FakePipeline()) is False
class TestNonRelativePolicy:
"""Verify the same pipeline works when relative actions are disabled (standard absolute policy)."""
def test_disabled_relative_step_is_noop(self):
relative_step = RelativeActionsProcessorStep(enabled=False)
absolute_step = AbsoluteActionsProcessorStep(enabled=False, relative_step=relative_step)
state = torch.randn(1, ACTION_DIM)
actions = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
transition = batch_to_transition({ACTION: actions.clone(), OBS_STATE: state})
t1 = relative_step(transition)
torch.testing.assert_close(t1[TransitionKey.ACTION], actions)
t2 = absolute_step({TransitionKey.ACTION: actions.clone()})
torch.testing.assert_close(t2[TransitionKey.ACTION], actions)
def test_eval_loop_without_relative_actions(self):
"""Full eval loop simulation with relative actions disabled: original and processed actions are identical."""
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
stats = {
ACTION: {
"mean": torch.zeros(ACTION_DIM).numpy(),
"std": torch.ones(ACTION_DIM).numpy(),
}
}
relative_step = RelativeActionsProcessorStep(enabled=False)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
absolute_step = AbsoluteActionsProcessorStep(enabled=False, relative_step=relative_step)
rtc = RTCProcessor(_make_rtc_config())
queue = ActionQueue(_make_rtc_config())
state = torch.randn(1, ACTION_DIM)
relative_step(batch_to_transition({OBS_STATE: state}))
model_output = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
post = absolute_step(unnormalizer({TransitionKey.ACTION: model_output.clone()}))[
TransitionKey.ACTION
].squeeze(0)
original = model_output.squeeze(0)
# With identity norm and no relative-action transform, original and postprocessed should match
torch.testing.assert_close(original, post, atol=1e-5, rtol=1e-5)
queue.merge(original, post, real_delay=0)
for _ in range(5):
queue.get()
prev_actions = queue.get_left_over()
assert prev_actions is not None
# RTC guidance works the same way (absolute space)
x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
result = rtc.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_actions.unsqueeze(0),
inference_delay=3,
time=0.5,
original_denoise_step_partial=lambda x: torch.zeros_like(x),
)
assert result.shape == x_t.shape
def test_detect_relative_returns_false_when_disabled(self):
class FakePipeline:
steps = [RelativeActionsProcessorStep(enabled=False)]
assert not _detect_relative_actions(FakePipeline())
def test_detect_relative_returns_false_when_absent(self):
class FakePipeline:
steps = []
assert not _detect_relative_actions(FakePipeline())
class TestMultiChunkConsistency:
"""Test multiple RTC iterations with relative actions maintain consistency."""
def test_three_iteration_pipeline(self):
"""Simulate three consecutive RTC iterations and verify queue state consistency."""
relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()
queue = ActionQueue(_make_rtc_config())
states = [torch.randn(1, ACTION_DIM) + i * 0.01 for i in range(3)]
for i in range(3):
queue.get_left_over()
relative_step(batch_to_transition({OBS_STATE: states[i]}))
model_output = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
post_transition = absolute_step(unnormalizer({TransitionKey.ACTION: model_output.clone()}))
postprocessed = post_transition[TransitionKey.ACTION].squeeze(0)
original = model_output.squeeze(0)
delay = min(i * 2, CHUNK_SIZE - 1)
queue.merge(original, postprocessed, real_delay=delay)
for _ in range(5):
action = queue.get()
assert action is not None
assert action.shape == (ACTION_DIM,)
# After 3 iterations, queue should still be in valid state
assert queue.qsize() > 0
leftover = queue.get_left_over()
assert leftover is not None
assert leftover.ndim == 2
assert leftover.shape[1] == ACTION_DIM
def test_leftover_and_processed_differ_when_relative_enabled(self):
"""With relative actions enabled, original leftovers (relative) must differ from processed (absolute)."""
relative_step, _, unnormalizer, absolute_step = _make_relative_pipeline()
queue = ActionQueue(_make_rtc_config())
state = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]])
relative_step(batch_to_transition({OBS_STATE: state}))
model_relatives = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
post = absolute_step(unnormalizer({TransitionKey.ACTION: model_relatives.clone()}))[
TransitionKey.ACTION
].squeeze(0)
original = model_relatives.squeeze(0)
queue.merge(original, post, real_delay=0)
relative_leftover = queue.get_left_over()
# Leftovers (relative) must differ from the postprocessed absolute actions
assert not torch.allclose(relative_leftover, post, atol=1e-3)
state_expanded = state.squeeze(0).unsqueeze(0).expand_as(relative_leftover)
torch.testing.assert_close(post, relative_leftover + state_expanded, atol=1e-5, rtol=1e-5)
def test_rtc_guidance_uses_relative_space(self):
"""Verify that RTC denoise_step receives relative-space leftovers, not absolute."""
relative_step, _, unnormalizer, absolute_step = _make_relative_pipeline()
rtc = RTCProcessor(_make_rtc_config())
queue = ActionQueue(_make_rtc_config())
state = torch.tensor([[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]])
relative_step(batch_to_transition({OBS_STATE: state}))
model_relatives = 0.1 * torch.randn(1, CHUNK_SIZE, ACTION_DIM)
post = absolute_step(unnormalizer({TransitionKey.ACTION: model_relatives.clone()}))[
TransitionKey.ACTION
].squeeze(0)
original = model_relatives.squeeze(0)
queue.merge(original, post, real_delay=0)
for _ in range(5):
queue.get()
prev_left_over = queue.get_left_over()
# prev_left_over should be small relative offsets (around 0.1 * randn), not large absolute values
assert prev_left_over.abs().mean() < 5.0, (
f"Leftover should be small relative offsets, got mean abs {prev_left_over.abs().mean():.2f}"
)
x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
def denoise(x):
return torch.zeros_like(x)
result = rtc.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_left_over.unsqueeze(0),
inference_delay=3,
time=0.5,
original_denoise_step_partial=denoise,
)
assert result.shape == x_t.shape

View File

@@ -42,6 +42,8 @@ from lerobot.policies.factory import (
make_pre_post_processors,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
@@ -460,3 +462,45 @@ def test_act_temporal_ensembler():
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4)
def test_vqbet_discretize_keeps_buffers_on_device():
"""Regression test: VQBeTHead.discretize() must not move registered buffers off the model device.
Previously, `self.vqvae_model.discretized = torch.tensor(True)` replaced the
registered buffer with a new CPU tensor, causing DDP to crash with:
RuntimeError: No backend type associated with device type cpu
The fix uses `.fill_(True)` to update in-place, preserving device placement.
"""
config = VQBeTConfig()
config.input_features = {
OBS_IMAGES: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 96, 96)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(6,)),
}
config.output_features = {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
}
# Tiny sizes for fast CPU/GPU execution.
config.n_vqvae_training_steps = 3
config.vqvae_n_embed = 8
config.vqvae_embedding_dim = 32
config.vqvae_enc_hidden_dim = 32
config.action_chunk_size = 2
config.crop_shape = (84, 84)
head = VQBeTHead(config).to(DEVICE)
vqvae = head.vqvae_model
dummy_actions = torch.randn(4, config.action_chunk_size, config.action_feature.shape[0], device=DEVICE)
n_steps = config.n_vqvae_training_steps
for _ in range(n_steps):
head.discretize(n_steps, dummy_actions)
assert vqvae.discretized.device.type == torch.device(DEVICE).type, (
"vqvae_model.discretized was moved off the model device after discretize(). "
"Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
)
assert vqvae.vq_layer.freeze_codebook.device.type == torch.device(DEVICE).type, (
"vq_layer.freeze_codebook was moved off the model device after discretize(). "
"Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
)

View File

@@ -0,0 +1,346 @@
"""Tests for relative action transforms — full pipeline validation.
Tests the complete flow matching OpenPI:
raw actions → RelativeActions → Normalize(relative_stats) → model → Unnormalize → AbsoluteActions
Uses real dataset: lerobot-data-collection/dagger_final_1_21
"""
import numpy as np
import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.compute_stats import get_feature_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor import TransitionKey, batch_to_transition
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from lerobot.processor.relative_action_processor import (
AbsoluteActionsProcessorStep,
RelativeActionsProcessorStep,
to_absolute_actions,
to_relative_actions,
)
from lerobot.utils.constants import ACTION, OBS_STATE
CHUNK_SIZE = 10
REPO_ID = "lerobot-data-collection/dagger_final_1_21"
@pytest.fixture(scope="module")
def dataset():
return LeRobotDataset(REPO_ID, episodes=[0])
@pytest.fixture(scope="module")
def action_dim(dataset):
return dataset.meta.features["action"]["shape"][0]
def _build_action_chunks(dataset, chunk_size, max_chunks=50):
"""Build action chunks from hf_dataset, like the training script does."""
hf = dataset.hf_dataset
total = len(hf)
all_ep = torch.tensor([int(hf[i]["episode_index"]) for i in range(total)])
chunks, states = [], []
for i in range(total - chunk_size + 1):
if all_ep[i] != all_ep[i + chunk_size - 1]:
continue
chunk_actions = torch.stack([hf[i + k]["action"] for k in range(chunk_size)]).float()
state = hf[i]["observation.state"].float()
chunks.append(chunk_actions)
states.append(state)
if len(chunks) >= max_chunks:
break
assert len(chunks) > 0, f"No valid chunks found. total={total}, ep_indices={all_ep.tolist()}"
return torch.stack(chunks), torch.stack(states)
def _compute_relative_chunk_stats(action_chunks, states, mask):
all_chunks = []
for actions, state in zip(action_chunks, states, strict=True):
relative = to_relative_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
all_chunks.append(relative.numpy())
all_relative = np.concatenate(all_chunks, axis=0)
return get_feature_stats(all_relative, axis=0, keepdims=all_relative.ndim == 1)
# Basic roundtrip tests
def test_roundtrip_3d(action_dim):
actions = torch.randn(4, CHUNK_SIZE, action_dim)
state = torch.randn(4, action_dim)
mask = [True] * action_dim
recovered = to_absolute_actions(to_relative_actions(actions, state, mask), state, mask)
torch.testing.assert_close(recovered, actions)
def test_roundtrip_2d(action_dim):
actions = torch.randn(4, action_dim)
state = torch.randn(4, action_dim)
mask = [True] * action_dim
recovered = to_absolute_actions(to_relative_actions(actions, state, mask), state, mask)
torch.testing.assert_close(recovered, actions)
def test_no_mutation(action_dim):
actions = torch.randn(2, CHUNK_SIZE, action_dim)
original = actions.clone()
state = torch.randn(2, action_dim)
to_relative_actions(actions, state, [True] * action_dim)
torch.testing.assert_close(actions, original)
def test_exclude_joints_supports_partial_name_matching():
names = [
"right_joint_1.pos",
"right_gripper.pos",
"left_joint_1.pos",
"left_gripper.pos",
]
step = RelativeActionsProcessorStep(enabled=True, exclude_joints=["gripper"], action_names=names)
assert step._build_mask(len(names)) == [True, False, True, False]
# Chunk-level relative stats test
def test_chunk_stats_have_larger_std_than_frame_stats(dataset, action_dim):
"""Chunk-level relative stats should have larger std than per-frame relative stats."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
chunk_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
# Per-frame stats
hf = dataset.hf_dataset
n = min(500, len(hf))
frame_actions = torch.stack([hf[i]["action"] for i in range(n)]).float()
frame_states = torch.stack([hf[i]["observation.state"] for i in range(n)]).float()
frame_relatives = to_relative_actions(frame_actions, frame_states, mask).numpy()
frame_stats = get_feature_stats(frame_relatives, axis=0, keepdims=frame_relatives.ndim == 1)
assert chunk_stats["std"].mean() >= frame_stats["std"].mean(), (
f"Chunk std ({chunk_stats['std'].mean():.4f}) should be >= "
f"frame std ({frame_stats['std'].mean():.4f})"
)
# Full pipeline roundtrip: relative → normalize → unnormalize → absolute
def test_full_pipeline_roundtrip(dataset, action_dim):
"""Test the complete OpenPI pipeline: relative → normalize → unnormalize → absolute."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
relative_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
stats = {ACTION: dict(relative_stats.items())}
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
relative_step = RelativeActionsProcessorStep(enabled=True)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
original_actions = action_chunks[0].unsqueeze(0)
state = states[0].unsqueeze(0)
batch = {ACTION: original_actions, OBS_STATE: state}
transition = batch_to_transition(batch)
# Forward: relative → normalize
t1 = relative_step(transition)
t2 = normalizer(t1)
normalized_action = t2[TransitionKey.ACTION]
assert normalized_action.abs().mean() < 10, (
f"Normalized actions should be in reasonable range, got mean abs {normalized_action.abs().mean():.2f}"
)
# Reverse: unnormalize → absolute
t3 = unnormalizer(t2)
t4 = absolute_step(t3)
recovered_actions = t4[TransitionKey.ACTION]
torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4)
def test_normalized_relative_values_are_reasonable(dataset, action_dim):
"""With correct chunk stats, normalized relative actions should be in a reasonable range."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
relative_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
mean = torch.tensor(relative_stats["mean"]).float()
std = torch.tensor(relative_stats["std"]).float()
all_normalized = []
for actions, state in zip(action_chunks, states, strict=True):
relative = to_relative_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
normalized = (relative - mean) / (std + 1e-6)
all_normalized.append(normalized)
all_normalized = torch.cat(all_normalized, dim=0)
pct_in_range = (all_normalized.abs() < 5).float().mean()
assert pct_in_range > 0.9, (
f"Only {pct_in_range * 100:.1f}% of normalized values in [-5, 5], expected >90%"
)
assert all_normalized.mean().abs() < 1.0, (
f"Mean of normalized relative actions is {all_normalized.mean():.2f}, expected near 0"
)
def test_processor_step_roundtrip(dataset, action_dim):
"""RelativeActionsProcessorStep applies relative offsets; to_absolute_actions recovers original."""
hf = dataset.hf_dataset
batch = {
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
}
original_actions = batch[ACTION].clone()
transition = batch_to_transition(batch)
step = RelativeActionsProcessorStep(enabled=True)
relative_transition = step(transition)
assert not torch.allclose(relative_transition[TransitionKey.ACTION], original_actions)
state = transition[TransitionKey.OBSERVATION][OBS_STATE]
mask = [True] * action_dim
recovered = to_absolute_actions(relative_transition[TransitionKey.ACTION], state, mask)
torch.testing.assert_close(recovered, original_actions)
def test_processor_step_disabled_is_noop(dataset, action_dim):
"""enabled=False should be a no-op."""
hf = dataset.hf_dataset
batch = {
ACTION: torch.stack([hf[i]["action"] for i in range(2)]),
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(2)]),
}
original = batch[ACTION].clone()
transition = batch_to_transition(batch)
result = RelativeActionsProcessorStep(enabled=False)(transition)
torch.testing.assert_close(result[TransitionKey.ACTION], original)
# Training batch shape validation
def test_relative_with_action_chunks(dataset, action_dim):
"""Verify relative actions work correctly with (B, chunk_size, action_dim) shaped actions."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
# Simulate a training batch: actions=(B, chunk_size, action_dim), state=(B, state_dim)
batch_actions = action_chunks[:4] # (4, chunk_size, action_dim)
batch_states = states[:4] # (4, state_dim)
mask = [True] * action_dim
relative = to_relative_actions(batch_actions, batch_states, mask)
# First action in each chunk should be close to zero (action[t] - state[t] ≈ small)
first_relatives = relative[:, 0, :] # (B, action_dim)
assert first_relatives.abs().mean() < relative.abs().mean(), (
f"First action in chunk should have smaller relative offset than average. "
f"First: {first_relatives.abs().mean():.4f}, Average: {relative.abs().mean():.4f}"
)
# Later actions should have larger relative offsets
last_relatives = relative[:, -1, :] # (B, action_dim)
assert last_relatives.abs().mean() >= first_relatives.abs().mean(), (
f"Last action in chunk should have >= relative offset than first. "
f"Last: {last_relatives.abs().mean():.4f}, First: {first_relatives.abs().mean():.4f}"
)
# Roundtrip
recovered = to_absolute_actions(relative, batch_states, mask)
torch.testing.assert_close(recovered, batch_actions)
def test_relative_stats_match_actual_data_distribution(dataset, action_dim):
"""Verify computed stats match the actual relative-action distribution."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
# Compute stats like the training script does
relative_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
# Also compute directly
all_relatives = []
for actions, state in zip(action_chunks, states, strict=True):
rel = to_relative_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
all_relatives.append(rel)
all_relatives_tensor = torch.cat(all_relatives, dim=0)
# Compare mean
actual_mean = all_relatives_tensor.mean(dim=0).numpy()
np.testing.assert_allclose(relative_stats["mean"], actual_mean, atol=0.01)
# Compare std
actual_std = all_relatives_tensor.std(dim=0).numpy()
np.testing.assert_allclose(relative_stats["std"], actual_std, atol=0.1)
# Verify q01 < mean < q99
assert (relative_stats["q01"] < relative_stats["mean"]).all(), "q01 should be < mean"
assert (relative_stats["mean"] < relative_stats["q99"]).all(), "mean should be < q99"
def test_quantile_normalization_roundtrip(dataset, action_dim):
"""Full roundtrip with QUANTILES normalization (what OpenPI uses for pi05)."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
relative_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
stats = {ACTION: dict(relative_stats.items())}
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
norm_map = {FeatureType.ACTION: NormalizationMode.QUANTILES}
relative_step = RelativeActionsProcessorStep(enabled=True)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
original_actions = action_chunks[0].unsqueeze(0)
state = states[0].unsqueeze(0)
batch = {ACTION: original_actions, OBS_STATE: state}
transition = batch_to_transition(batch)
# Forward: relative → quantile normalize
t1 = relative_step(transition)
t2 = normalizer(t1)
normalized = t2[TransitionKey.ACTION]
# Most values should be in [-1, 1] with quantile normalization
pct_in_range = (normalized.abs() < 2).float().mean()
assert pct_in_range > 0.5, f"Only {pct_in_range * 100:.1f}% in [-2, 2] after quantile norm, expected >50%"
# Reverse: unnormalize → absolute
t3 = unnormalizer(t2)
t4 = absolute_step(t3)
recovered = t4[TransitionKey.ACTION]
torch.testing.assert_close(recovered, original_actions, atol=1e-3, rtol=1e-3)
def test_state_not_modified_by_relative_processor(dataset, action_dim):
"""State should never be modified by the relative-action processor."""
hf = dataset.hf_dataset
batch = {
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
}
original_state = batch[OBS_STATE].clone()
transition = batch_to_transition(batch)
step = RelativeActionsProcessorStep(enabled=True)
result = step(transition)
result_state = result[TransitionKey.OBSERVATION][OBS_STATE]
torch.testing.assert_close(result_state, original_state)