Compare commits

..

14 Commits

Author SHA1 Message Date
Nikodem Bartnik
a24d10f5bb add quick AI draft for quickstart 2026-05-26 13:10:24 +02:00
Nikodem Bartnik
32279544ea add new docs chapters structure 2026-05-26 12:01:33 +02:00
Pepijn
8194897994 fix(deps): cap placo below 0.9.16 and harden kinematics import (#3647)
* fix(deps): cap placo below 0.9.16 and harden kinematics import

placo 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable
on Ubuntu 24.04 (noble ships urdfdom 3.x). Importing placo on that base
crashes with:

  ImportError: liburdfdom_sensor.so.4.0: cannot open shared object file

This broke nightly Latest Deps tests (CPU and GPU) when the lockfile
upgrade picked placo 0.9.16, since lerobot.model.kinematics
unconditionally imports placo when _placo_available is true, and that
check (importlib.util.find_spec) cannot detect dlopen failures of
transitive shared libraries — so unrelated subsystems (RL actor,
gym_manipulator) became unimportable.

Two changes:

1. Pin placo to <0.9.16 in pyproject.toml + regenerate uv.lock
   (0.9.16 → 0.9.15). Short-term unblock for nightly CI until system
   urdfdom 4.x is broadly available.

2. Harden the import guard in src/lerobot/model/kinematics.py:
   wrap 'import placo' in try/except ImportError so a missing
   transitive .so no longer crashes module import. RobotKinematics
   instantiation now raises an informative ImportError citing the
   underlying dlopen failure via _raise_if_placo_unusable().

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(kinematics): hoist _placo_runtime_error to module scope for mypy

Mypy walks the TYPE_CHECKING branch in which the runtime else-block is
not executed, so _placo_runtime_error was only defined at runtime and
mypy reported 'Name "_placo_runtime_error" is not defined' on the
three references inside _raise_if_placo_unusable. Declare the symbol
unconditionally at module scope with a default of None; the runtime
import-failure branch still assigns to it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* style(kinematics): drop verbose comments around placo import guard

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 12:03:07 +02:00
Haoming Song
9f437d86b6 fix(groot): align GR00TN15Config with transformers config dataclasses (#3606)
* fix(gr00t): fix gr00t config dataclass init TypeError

* fix(groot): guard strict config decorator without transformers for passing CI

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-05-22 10:31:04 +02:00
Haoming Song
b74a551d38 fix(pi0, pi05): stabilize torch.compile and expand test coverage (#3610)
* chore(gr00t): sync with #3606 for fixing gr00t config crash

* fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions

* fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete

* feat(test): add compile test and benchamrk for pi0 and pi05

* feat(test): add comprehensive testing for pi0 and pi05. Including processor, forward, sample action, etc.
2026-05-22 10:29:34 +02:00
Nikodem Bartnik
c0a2e9814d fix examples (#3623)
- Fixed broken API examples in Lerobot Imitation Learning Documentation
- Teleoperation with cameras improved by adding a fixed frequency in the loop (without it the cameras feed gets very slow)
- Wrapped record example script in main() to avoid problems on Mac
- Previously teleoperation example was using SO-ARM and teleoperation with cameras was using Koch. I changed it to use SO-ARM in all of the examples.
- Added section on how to train with HF Jobs - CLI and Python examples
- Replaced lerobot-record with lerobot-rollout in policies examples
2026-05-21 22:14:07 +02:00
Khalil Meftah
bac4f61eae refactor: support custom progress parquet overlays (#3640) 2026-05-21 14:32:10 +02:00
Virgileboat
f4b834844e Feat/clean can bus (#3526)
* change timeout  for handshake

* enforce last state read when querry

* change import order

* fix(motors): flush stale robstride RX and harden feedback drain

* robstride: remove redundant timeout and max_messages casts

* bugfix + %-style

* update exception catch
2026-05-21 11:44:04 +02:00
Roham Z. Nobari
dfdc48a7f1 fix(datasets): bound VideoDecoderCache to prevent OOM on large datasets (#3614)
VideoDecoderCache used an unbounded dict keyed on absolute path, with no
eviction in the standard LeRobotDataset path. With shuffled iteration over
datasets that have many distinct mp4 files, every DataLoader worker
accumulated one cached (VideoDecoder, fsspec file handle) pair per distinct
path it had ever touched. Per-entry cost is ~3-5 MB of host RAM plus one
open FD; at ~8 k entries this is roughly 30 GB per worker.

This was hit in the wild during a SmolVLA training run on a 4,195-episode
SO-101 dataset (8,390 mp4s, two cameras per episode). dmesg showed
anon-rss climbing to 34.9 GB on a single pt_data_worker before the OOM
killer fired ~30 min into training; with --num_workers=8 the per-worker
peak halved to 17.9 GB, which is the expected inverse-scaling signature
when the leak is per-decode and the workload is split across workers. The
working workaround on the affected platform was --dataset.video_backend=pyav,
because the pyav path opens/closes per call and never touches this cache.

Switch the backing store to an OrderedDict and evict LRU entries when the
cap is reached, closing the evicted file handle inside the lock so we do
not leak FDs either. Default cap is DEFAULT_DECODER_CACHE_SIZE = 100,
overridable via LEROBOT_VIDEO_DECODER_CACHE_SIZE or by passing max_size=
to the constructor; max_size=None restores the legacy unbounded behaviour
for callers that need it.

Validation on the original failing workload (decode_video_frames_torchcodec
called over real mp4s from the affected SO-101 dataset):

  unbounded:    300 files  ->  +1087 MB host RSS,  cache=300, still climbing
  cap=50:       500 files  ->   +266 MB host RSS,  cache=50,  stable
  cap=50:      2000 calls  ->   +312 MB host RSS,  cache=50,  stable
  cap=100:     1000 calls  ->   +470 MB host RSS,  cache=100, stable

Three independent seeded runs at cap=50 agreed to within 1% (263 / 266 /
265 MB delta), and the 2000-call multi-pass run shows RSS plateaus after
the cap is reached instead of drifting.

Tests in tests/datasets/test_video_decoder_cache.py cover:
default-is-bounded, size cap, LRU ordering, FD close on eviction, FD close
on clear(), cache-hit invariance, max_size=None fallback, and env-var
override. No regressions in test_video_encoding.py, test_streaming.py, or
test_dataset_reader.py (73 prior tests still pass alongside the 8 new ones).
2026-05-19 16:54:25 +02:00
四七
6a8878a639 fix(datasets): normalize shape=(1,) numeric values before HF encoding (#3344)
* fix(datasets): normalize shape=(1,) numeric values before save

* test(datasets): cover shape=(1,) int/bool and finalize

Co-authored-by: Copilot <copilot@github.com>
2026-05-19 16:53:19 +02:00
Caroline Pascal
d38eb89f71 feat(video re-encoding): Adding utility and dataset edition tool for video re-encoding (#3611)
* feat(utility): adding video re-encode utility

* feat(edit): adding a new lerobot-edit-dataset tool to re-encode all the videos of a dataset

* chore(format): formatting code

* chore(review): fix Claude reviews

* test(reencode dataset): adding missing test for reencode dataset
2026-05-19 14:46:14 +02:00
Pepijn
7ab4936b1b Add extensive language support (#3467)
* Add extensive language support

* Address review: split persistent/event schemas, drop event timestamps

- recipe.py: derive _VALID_ROLES/_VALID_STREAMS from MessageRole/MessageStream Literals
- dataset_metadata.py: keep CODEBASE_VERSION at v3.0
- language.py: remove RESERVED_STYLES; split arrow/feature schemas into
  persistent (with timestamp) and event (without timestamp); add docstrings
- language_render.py: events use frame-row timestamp implicitly; no
  per-event timestamp filtering or sorting
- converters.py: drop unused subtask_key passthrough
- add docstrings to new public APIs (recipe, render_messages_processor, collate)
- update tests for split schemas; revert uv.lock

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Add docstrings to all new helpers; revert uv.lock

Covers private helpers in recipe.py, language.py, language_render.py,
and render_messages_processor.py. Also reverts uv.lock to main (it was
re-generated by `uv run` during local checks).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat(language): add motion (persistent) and trace (event-only) styles

Promote the previously-reserved motion/trace styles to first-class core
styles. motion routes to language_persistent (it tracks robot state over
time); trace routes to language_events (single-moment annotations).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat(language): per-camera tagging on view-dependent styles

Adds a nullable `camera` field to the language row struct (both persistent
and event variants) so view-dependent styles like `vqa` can carry which
`observation.images.*` view they were grounded against. Without this,
multi-camera datasets ended up with multiple `(vqa, role)` rows at the
same timestamp that the resolver could not disambiguate.

- `language.py`: add `camera` to PERSISTENT_ROW_FIELDS / EVENT_ROW_FIELDS,
  to both Arrow struct types and the HF datasets feature mappings;
  introduce VIEW_DEPENDENT_STYLES = {vqa, motion, trace} plus
  `is_view_dependent_style` and `validate_camera_field` helpers (camera
  required iff style is view-dependent).
- `language_render.py`: thread an optional `camera=` kwarg through every
  resolver (`active_at`, `emitted_at`, `nth_prev`, `nth_next`) and through
  `_matching_rows` / `_select_*`, so recipes can disambiguate per-camera
  VQA with `emitted_at(t, style=vqa, role=assistant, camera=...)`.
  Without a `camera` filter, multi-row matches keep raising the existing
  ambiguity error — which is the desired behaviour on multi-camera data.
- `recipes/pi05_hirobot.yaml`: replace the single `ask_vqa` branch with
  `ask_vqa_top` and `ask_vqa_wrist` per-camera sub-recipes (each carrying
  the matching image block), keeping the original 0.20 budget and
  documenting the customization point for datasets with different cameras.
- Tests: schema test asserts the new field order; new tests cover
  `is_view_dependent_style`, `validate_camera_field` (both required and
  forbidden directions), per-camera `emitted_at` filtering, and the
  ambiguity error when two cameras emit `(vqa, assistant)` at the same
  timestamp without a `camera=` filter. RenderMessagesStep + dataset
  passthrough fixtures updated to include the new field.
- `docs/source/language_and_recipes.mdx`: document the `camera` field,
  the per-camera resolver pattern, and the canonical recipe convention.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(language): drop motion from VIEW_DEPENDENT_STYLES

Motion primitives are described in robot-frame (joint / Cartesian) terms,
not pixel space, so they are camera-agnostic. Only `vqa` (event) and
`trace` (event, pixel-trajectory) are view-dependent.

The `camera` field stays on PERSISTENT_ROW_FIELDS for schema symmetry —
the validator, resolver, and HF feature mapping behave identically across
the two columns regardless of which styles populate `camera` today —
but persistent rows now always have `camera=None` in practice.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat(language): task_aug style + automatic ${task} rephrasing rotation

Adds task-prompt diversity (Xiao 2022 / CAST) without touching
``meta/tasks.parquet`` or forcing recipes to opt in. The plan reserved
``task_aug`` as a future style; this lands it now.

- ``language.py``: add ``task_aug`` to ``CORE_STYLES`` and
  ``PERSISTENT_STYLES``. ``column_for_style("task_aug")`` returns
  ``language_persistent`` so PR 2 writers route it correctly.

- ``language_render.py``: ``_resolve_task`` now consults the persistent
  slice for rows of ``style="task_aug", role="user"``. When any exist
  it picks one deterministically by ``sample_idx`` (blake2b-keyed, not
  Python's randomized hash) so an epoch sees every rephrasing of every
  episode while the same sample still resolves identically across
  reruns. Falls back to the canonical ``meta/tasks.parquet`` task when
  no rephrasings are present, so existing datasets and unannotated runs
  keep their behaviour. Explicit ``task=`` overrides still win.

- Tests: rephrasing coverage across samples, determinism on repeat
  ``sample_idx``, fallback when persistent has no ``task_aug`` rows,
  and explicit override priority.

Recipes get this for free: any ``${task}`` placeholder rotates through
the available rephrasings. Recipes that want the literal canonical task
can override the binding.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat(language): tool catalog in meta/info.json + LeRobotDatasetMetadata.tools

Stores OpenAI-style function schemas at ``meta/info.json["tools"]`` so
datasets can declare which tools are available (today: just ``say``;
tomorrow: per-dataset extensions). The ``DEFAULT_TOOLS`` constant
fills in for unannotated datasets so chat-template consumers don't
have to special-case anything.

Three pieces:

- ``language.py``: ``SAY_TOOL_SCHEMA`` and ``DEFAULT_TOOLS``
  constants. Single source of truth — PR 2's writer and PR 3's
  runtime tool registry will both import from here instead of
  duplicating the dict.
- ``dataset_metadata.py``: ``LeRobotDatasetMetadata.tools`` property
  reads ``info.json["tools"]`` and falls back to ``DEFAULT_TOOLS``.
  Returns deep-copied dicts so callers can mutate the result safely.
- ``docs/source/tools.mdx``: spec page covering the catalog, per-row
  invocations, and the three-step "how to add a new tool" workflow
  (declare schema, implement, register). Linked from the docs
  toctree under the Datasets section.

This lays the groundwork for PR 2's pipeline writing the catalog out
during annotation, and PR 3's ``src/lerobot/tools/`` package shipping
runnable implementations (one file per tool — first up:
``say.py`` wrapping Kyutai's pocket-tts).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Apply ruff and prettier formatting after merge

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* refactor(language): unify resolver dispatch and prune redundant test scaffolding

* Drop the unused `events` kwarg from `active_at`/`nth_prev`/`nth_next`;
  only `emitted_at` actually consults events. The dispatcher in
  `_resolve_spec` now passes events conditionally.
* Replace the dual `_persistent_sort_key`/`_event_sort_key` pair with a
  single `_row_sort_key` and drop the `sort_key` parameter from
  `_select_one`. Event rows lack `timestamp` (it is implicit in the
  frame) and now default to `0.0` for sort purposes — the
  `(style, role)` tiebreaker is unchanged.
* Inline `_select_latest` into `active_at` (its only caller).
* Collapse `emitted_at`'s dual-branch into one `_select_one` call.
* Tighten `_validate_persistent_resolver` to a single
  `column_for_style(style) != LANGUAGE_PERSISTENT` check.
* Parameterize `test_per_camera_blend_renders_both_views` over the two
  cameras and factor the sub-recipe builder into `_vqa_subrecipe` so
  the test no longer hand-rolls two near-identical recipe blocks.

Net -98 LOC; behavior, public resolver names, and test expectations
unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(language): always raise on ambiguous resolver matches

`_select_one` previously skipped its ambiguity check whenever any of
`role`/`tool_name`/`camera` was set, on the assumption that the caller
had already pinned down a unique row. That left a real ambiguity hole
for VQA: with two cameras emitting `(vqa, assistant)` at the same
frame, `emitted_at(..., role="assistant")` silently picked the first
sorted row instead of telling the recipe to add `camera=...`. The
existing `test_emitted_at_raises_on_ambiguous_per_camera_vqa` test
already encoded the desired behavior.

Tighten the check: any time `len(rows) > 1` we now raise with the
selectors echoed back, so users see exactly which fields they passed
and that more is needed to disambiguate.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore: fix CI — collapse short ValueError to one line, refresh uv.lock

* `ruff format` on CI (newer version) wants the short `camera=None`
  ValueError on a single line.
* `uv.lock` was stale relative to `pyproject.toml`'s `datasets>=4.7.0`
  pin (and picked up upstream `s390x` marker fixes for cuda packages).
  CI runs `uv sync --locked` which rejected the divergence.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(language): keep base install green — drop processor re-export, gate dataset-extra tests

`lerobot.processor` re-exported `RenderMessagesStep` at the package
level, so importing anything from `lerobot.processor` pulled in
`lerobot.datasets.language` → `lerobot.datasets/__init__.py` →
`require_package("datasets")`, which fails in the Tier 1 base install
that intentionally omits the `[dataset]` extra. The chain bricked
collection for unrelated suites (`tests/policies/pi0_pi05/...`,
`tests/envs/...`, etc.).

* Stop re-exporting `RenderMessagesStep` from `lerobot.processor`. The
  only consumer (the test) already imports from the submodule.
  Document the deliberate omission in the module docstring.
* Add `pytest.importorskip("datasets", ...)` (and `pandas` where
  needed) at the top of the four PR-added tests that exercise the
  language stack:
  - tests/datasets/test_language.py
  - tests/datasets/test_language_render.py
  - tests/processor/test_render_messages_processor.py
  - tests/utils/test_collate.py

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(language): address review — tools accessor, motion docs, conditional collate

* **`meta.tools` actually reads `info.json["tools"]`.** `DatasetInfo`
  had no `tools` field, so `from_dict` silently dropped the key (it
  warned about unknown fields then discarded them) and the property
  always returned `DEFAULT_TOOLS`. Added `tools: list[dict] | None`
  to the dataclass; `to_dict()` drops it when unset so existing
  datasets keep a clean `info.json`. Fixed the accessor to read
  `self.info.tools` (the previous `.get(...)` would have raised
  AttributeError on the dataclass anyway). Added regression tests:
  fallback when absent, round-trip from disk, and round-trip
  through `DatasetInfo.from_dict` / `to_dict`.

* **`motion` is not view-dependent — fix the docs.** The mdx claimed
  rows of style `motion` must carry `camera`, but `VIEW_DEPENDENT_STYLES
  = {"vqa", "trace"}` and the validator agrees: motion primitives are
  joint/Cartesian-frame, not pixel-space. Updated both call-out
  paragraphs in `language_and_recipes.mdx`.

* **Conditional `collate_fn` swap.** Added `meta.has_language_columns`
  and gate the `lerobot_collate_fn` swap in `lerobot_train.py` on it,
  so non-language datasets keep PyTorch's `default_collate`. Also
  added a pass-through test in `test_collate.py` that asserts on a
  plain tensor batch the custom collate matches `default_collate`
  key-for-key, plus a test for the `None`-sample drop path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* review: dedupe regex, centralize column names, harden collate, more tests

* **#2 — dedupe `_PLACEHOLDER_RE`.** The same regex was compiled in
  `recipe.py` and `language_render.py`. Promote to module-level
  `PLACEHOLDER_RE` in `recipe.py` (its primary owner — declares
  template syntax) and import from `language_render.py`.
* **#3 — centralize language column names.** `io_utils.py` had
  hardcoded `{"language_persistent", "language_events"}` literals at
  two sites. Replace with `LANGUAGE_COLUMNS` import so a future column
  rename can't silently desync.
* **#4 — defensive collate preserved-keys.** `lerobot_collate_fn`
  silently filtered language fields from samples that didn't have
  them, which would hand downstream consumers a preserved list
  shorter than the tensor batch. Now: if any sample carries a key,
  every sample in the batch must carry it; otherwise raise a
  `ValueError` so the upstream rendering bug surfaces at the boundary.
* **#5 — `_scalar` rejects non-singleton lists.** Previously a zero-
  or multi-element list fell through and triggered confusing
  `float([])` errors downstream. Now raises `ValueError` with the
  actual length.
* **#6 — refactor `_extract_complementary_data`.** Replace 11 lines
  of `key = {... if ... else {}}` plus an 11-line splat dict with a
  single `_COMPLEMENTARY_KEYS` tuple iterated once.
* **#7 — document `EXTENDED_STYLES`.** Was an empty `set()` with no
  comment. Add a docstring explaining it's an intentional extension
  point: downstream modules append project-local styles before
  `column_for_style` is called.
* **#9 — `tools.mdx` notes the runtime layer is future work.** The
  page referenced `src/lerobot/tools/`, `registry.py`, and
  `get_tools(meta)` — none exist in this PR. Added a callout at the
  start of "How to add your own tool" plus a note on the
  implementations paragraph.
* **#10 — tests for YAML round-trip, malformed rows, blend
  validation.** `test_recipe.py` grew from 1 case to 12 covering:
  blend-or-messages exclusivity, target-turn requirement, blend
  emptiness, weight presence/positivity, nested-blend rejection,
  `from_dict` with nested blends, `from_yaml` / `load_recipe`
  agreement, top-level non-mapping rejection. Added a malformed-row
  test for `_normalize_rows` that asserts non-dict entries raise
  `TypeError`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* review: emitted_at uses 0.1s tolerance; MessageTurn requires stream at construction

* **Float tolerance in `emitted_at` for persistent styles.** The
  ``_timestamp(row) == t`` exact-equality check silently missed any
  caller that derived ``t`` arithmetically (e.g. ``frame_idx / fps``)
  even though the parquet timestamp would only differ by ULPs. Added
  ``EMITTED_AT_TOLERANCE_S = 0.1`` and check ``abs(...) <= tolerance``
  instead, with a docstring explaining why exact equality wasn't
  enough and why 0.1 s is safe at typical 30–100 Hz control rates.
  Test asserts the new behavior at half-window (matches) and
  double-window (no match) using the constant so it stays in sync.

* **`MessageTurn.stream` is required at construction.** It was typed
  ``MessageStream | None = None`` so YAML could omit ``stream:`` and
  pass the dataclass invariant — but ``_validate_rendered`` rejected
  ``None`` streams later, surfacing the error at the first sample
  instead of at recipe load. Now ``__post_init__`` raises
  ``ValueError`` if ``stream`` is ``None``, with the list of valid
  streams in the message. The redundant late-stage check in
  ``_validate_rendered`` is replaced with a one-line comment that
  cites the upstream invariant. Test pins the new construction-time
  rejection.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs(tools): drop follow-up-PR references

Reword the two callouts in `tools.mdx` to describe the runtime layer
in present tense ("not part of the catalog layer shipped today",
"those modules don't yet exist in the tree") instead of pointing at a
specific follow-up PR. Keeps the doc honest about what works now
without coupling it to a particular release order.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* review: address CarolinePascal feedback

- language timestamps: float64 -> float32 to match LeRobotDataset frame
  timestamps (Arrow struct + HF feature)
- dataset_metadata: hoist `.language` imports to module top — language.py
  has no lerobot imports, so there is no circular-import risk
- dataset_metadata: add a `meta.tools` setter that persists the catalog to
  info.json and reloads `meta.info`
- feature_utils: validate the `language` dtype instead of returning "" —
  warn (non-fatal) when a non-empty value is written at record time
- centralize the scalar-unwrap helper as `lerobot.utils.utils.unwrap_scalar`,
  shared by render_messages_processor and language_render
- docs: move `## Layer 2 — recipe anatomy` ahead of the resolver sections,
  which describe recipe bindings rather than dataset layout
- language_render: note in EMITTED_AT_TOLERANCE_S that persistent rows change
  on a human-action timescale, not the camera frame rate

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 14:46:11 +02:00
von Neumann 101
ca8c60a0ed Set OpenCV fourcc after size and fps (#3620)
* Set OpenCV fourcc after size and fps

* Set OpenCV fourcc last on Windows

* Add comment explaining DSHOW fourcc ordering
2026-05-19 14:06:41 +02:00
Pepijn
3c15fd8537 feat(robots): natively integrate Seeed Studio reBot B601-DM arm (#3624)
* feat(robots): natively integrate Seeed Studio reBot B601-DM arm

Add first-class LeRobot support for the Seeed Studio reBot arm, replacing
the out-of-tree `lerobot-robot-seeed-b601` / `lerobot-teleoperator-rebot-arm-102`
plugin packages.

New devices:
- robot `rebot_b601_follower` — single-arm B601-DM follower (6-DOF + gripper,
  Damiao CAN motors via `motorbridge`)
- robot `bi_rebot_b601_follower` — bimanual follower composing two single arms
- teleoperator `rebot_102_leader` — single-arm StarArm102 / reBot Arm 102 leader
  (FashionStar UART servos via `motorbridge-smart-servo`)
- teleoperator `bi_rebot_102_leader` — bimanual leader composing two single arms

The bimanual variants reuse the single-arm classes and namespace each arm's
observation/action keys with `left_` / `right_` prefixes, so a bimanual
StarArm102 leader can teleoperate a bimanual reBot B601 follower.

Optional SDK imports are guarded; a `rebot` extra installs `motorbridge` and
`motorbridge-smart-servo`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs: add reBot B601-DM calibration & dual-arm teleoperation guide

Add docs/source/rebot_b601.mdx covering single-arm and bimanual
calibration and teleoperation for the reBot B601-DM follower and
reBot Arm 102 leader, with zero-position reference images from the
Seeed Studio wiki. Register the page in the docs toctree.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs: fix reBot B601 MDX build (move JSON example out of <Tip>)

The doc-builder parses `{...}` inside MDX component children as a
Svelte expression, so the joint_directions JSON example broke the
build. Move it into a top-level fenced code block.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs: apply prettier formatting to reBot B601 page

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs: remove duplicate colocated reBot B601 page

docs/source/rebot_b601.mdx is the canonical, toctree-registered page;
the colocated rebot_b601.md was a redundant thinner copy.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs: clarify 6-DOF leader fallback comment in reBot B601 follower

Explain that holding wrist_yaw at zero is what lets a 6-DOF leader
(e.g. so100_leader / so101_leader) teleoperate the 7-DOF follower.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* refactor: address Caroline's PR review on reBot B601 integration

- leader: remove _validate_config (no other lerobot device validates its
  config; a key mismatch now surfaces as a plain KeyError)
- leader: simplify _round_to_valid_range to direct modular arithmetic
  instead of a bidirectional search loop
- leader: inline the single-use _clamp helper
- follower & leader: write MotorCalibration range_min/range_max from the
  configured joint_limits / joint_ranges instead of a fixed [-90, 90]
- docs: add a "Find the USB ports" section (lerobot-find-port) and move
  the brltty/permissions tip there; link the OpenArm page for SocketCAN
  adapter configuration

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 19:49:21 +02:00
111 changed files with 8489 additions and 2646 deletions

View File

@@ -0,0 +1,172 @@
- sections:
- local: index
title: LeRobot
- local: installation
title: Installation
- local: cheat-sheet
title: Cheat sheet
title: Get started
- sections:
- local: il_robots
title: Imitation Learning for Robots
- local: bring_your_own_policies
title: Adding a Policy
- local: integrate_hardware
title: Bring Your Own Hardware
- local: hilserl
title: Train a Robot with RL
- local: hilserl_sim
title: Train RL in Simulation
- local: multi_gpu_training
title: Multi GPU training
- local: hil_data_collection
title: Human In the Loop Data Collection
- local: peft_training
title: Training with PEFT (e.g., LoRA)
- local: rename_map
title: Using Rename Map and Empty Cameras
title: "Tutorials"
- sections:
- local: hardware_guide
title: Compute Hardware Guide
- local: torch_accelerators
title: PyTorch accelerators
title: "Compute & Hardware"
- sections:
- local: lerobot-dataset-v3
title: Using LeRobotDataset
- local: porting_datasets_v3
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
- local: language_and_recipes
title: Language Columns and Recipes
- local: tools
title: Tools
- local: video_encoding_parameters
title: Video encoding parameters
- local: streaming_video_encoding
title: Streaming Video Encoding
title: "Datasets"
- sections:
- local: act
title: ACT
- local: smolvla
title: SmolVLA
- local: pi0
title: π₀ (Pi0)
- local: pi0fast
title: π₀-FAST (Pi0Fast)
- local: pi05
title: π₀.₅ (Pi05)
- local: eo1
title: EO-1
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
- local: multi_task_dit
title: Multitask DiT Policy
- local: walloss
title: WALL-OSS
title: "Policies"
- sections:
- local: sarm
title: SARM
title: "Reward Models"
- sections:
- local: inference
title: Policy Deployment (lerobot-rollout)
- local: async
title: Use Async Inference
- local: rtc
title: Real-Time Chunking (RTC)
title: "Inference"
- sections:
- local: envhub
title: Environments from the Hub
- local: envhub_leisaac
title: Control & Train Robots in Sim (LeIsaac)
title: "Simulation"
- sections:
- local: adding_benchmarks
title: Adding a New Benchmark
- local: libero
title: LIBERO
- local: libero_plus
title: LIBERO-plus
- local: metaworld
title: Meta-World
- local: robotwin
title: RoboTwin 2.0
- local: robocasa
title: RoboCasa365
- local: robocerebra
title: RoboCerebra
- local: robomme
title: RoboMME
- local: envhub_isaaclab_arena
title: NVIDIA IsaacLab Arena Environments
- local: vlabench
title: VLABench
title: "Benchmarks"
- sections:
- local: introduction_processors
title: Introduction to Robot Processors
- local: debug_processor_pipeline
title: Debug your processor pipeline
- local: implement_your_own_processor
title: Implement your own processor
- local: processors_robots_teleop
title: Processors for Robots and Teleoperators
- local: env_processor
title: Environment Processors
- local: action_representations
title: Action Representations
title: "Robot Processors"
- sections:
- local: so101
title: SO-101
- local: so100
title: SO-100
- local: koch
title: Koch v1.1
- local: lekiwi
title: LeKiwi
- local: hope_jr
title: Hope Jr
- local: reachy2
title: Reachy 2
- local: unitree_g1
title: Unitree G1
- local: earthrover_mini_plus
title: Earth Rover Mini
- local: omx
title: OMX
- local: openarm
title: OpenArm
- local: rebot_b601
title: reBot B601-DM
title: "Robots"
- sections:
- local: phone_teleop
title: Phone
title: "Teleoperators"
- sections:
- local: cameras
title: Cameras
title: "Sensors"
- sections:
- local: notebooks
title: Notebooks
- local: feetech
title: Updating Feetech Firmware
- local: damiao
title: Damiao Motors and CAN Bus
title: "Resources"
- sections:
- local: contributing
title: Contribute to LeRobot
- local: backwardcomp
title: Backward compatibility
title: "About"

View File

@@ -1,168 +1,214 @@
# LeRobot documentation table of contents
#
# Ordering principle: gentle onboarding first, advanced/custom work last.
# Within each top-level section the same rule applies — concept/overview pages
# before reference/per-item pages.
#
# Pages marked "NEW (to create)" do not yet exist as .mdx files; they are
# placeholders for the redesign and must be authored before the docs build.
- sections:
- local: index
title: LeRobot
title: 🤗 LeRobot
- local: quickstart # NEW (to create) — 15-min zero-to-trained-ACT path
title: Quickstart
- local: installation
title: Installation
- local: core_concepts # NEW (to create) — datasets, policies, processors, robots, envs in one mental model
title: Core concepts
- local: cheat-sheet
title: Cheat sheet
title: Command cheat sheet
title: Get started
- sections:
- local: il_robots
title: Imitation Learning for Robots
- local: bring_your_own_policies
title: Adding a Policy
- local: integrate_hardware
title: Bring Your Own Hardware
- local: hilserl
title: Train a Robot with RL
- local: hilserl_sim
title: Train RL in Simulation
- local: multi_gpu_training
title: Multi GPU training
title: Imitation learning end-to-end
- local: hil_data_collection
title: Human In the Loop Data Collection
- local: peft_training
title: Training with PEFT (e.g., LoRA)
title: Human-in-the-loop data collection
- local: inference
title: Deploying a trained policy
- local: rename_map
title: Using Rename Map and Empty Cameras
title: "Tutorials"
title: Matching dataset keys to a policy (rename map)
title: Your first project
- sections:
- local: hardware_guide
title: Compute Hardware Guide
title: Compute hardware guide
- local: torch_accelerators
title: PyTorch accelerators
title: "Compute & Hardware"
- local: multi_gpu_training
title: Multi-GPU training
- local: peft_training
title: Parameter-efficient fine-tuning (LoRA)
title: Training
- sections:
- local: lerobot-dataset-v3
title: Using LeRobotDataset
- local: porting_datasets_v3
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
- local: dataset_subtask
title: Using Subtasks in the Dataset
title: Dataset tools
- local: language_and_recipes
title: Language columns & recipes
- local: tools
title: Tool calls in datasets
- local: video_encoding_parameters
title: Video encoding parameters
- local: streaming_video_encoding
title: Streaming Video Encoding
title: "Datasets"
title: Streaming video encoding
- local: porting_datasets_v3
title: Porting datasets to v3
title: Datasets
- sections:
- local: act
title: ACT
- local: smolvla
title: SmolVLA
- local: pi0
title: π₀ (Pi0)
- local: pi0fast
title: π₀-FAST (Pi0Fast)
- local: pi05
title: π₀.₅ (Pi05)
- local: eo1
title: EO-1
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
- local: multi_task_dit
title: Multitask DiT Policy
- local: walloss
title: WALL-OSS
title: "Policies"
- local: policies_overview # NEW (to create) — concept hub + "choose a policy" decision guide
title: Choosing a policy
- sections:
- local: act
title: ACT
- local: smolvla
title: SmolVLA
- local: pi0
title: π₀ (Pi0)
- local: pi0fast
title: π₀-FAST
- local: pi05
title: π₀.₅ (Pi05)
- local: eo1
title: EO-1
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
- local: walloss
title: WALL-OSS
- local: multi_task_dit
title: Multitask DiT
title: Policy reference
title: Policies
- sections:
- local: sarm
title: SARM
title: "Reward Models"
- sections:
- local: inference
title: Policy Deployment (lerobot-rollout)
- local: async
title: Use Async Inference
title: Async inference
- local: rtc
title: Real-Time Chunking (RTC)
title: "Inference"
title: Real-time chunking (RTC)
title: Real-time deployment
- sections:
- local: hilserl
title: Train a robot with RL (HIL-SERL)
- local: hilserl_sim
title: Train RL in simulation
- local: sarm
title: SARM reward model
title: Reinforcement learning
- sections:
- local: envhub
title: Environments from the Hub
- local: envhub_leisaac
title: Control & Train Robots in Sim (LeIsaac)
title: "Simulation"
- sections:
- local: adding_benchmarks
title: Adding a New Benchmark
- local: libero
title: LIBERO
- local: libero_plus
title: LIBERO-plus
- local: metaworld
title: Meta-World
- local: robotwin
title: RoboTwin 2.0
- local: robocasa
title: RoboCasa365
- local: robocerebra
title: RoboCerebra
- local: robomme
title: RoboMME
title: LeIsaac — control & train in sim
- local: envhub_isaaclab_arena
title: NVIDIA IsaacLab Arena Environments
- local: vlabench
title: VLABench
title: "Benchmarks"
title: NVIDIA IsaacLab Arena environments
- sections:
- local: libero
title: LIBERO
- local: libero_plus
title: LIBERO-plus
- local: metaworld
title: Meta-World
- local: robotwin
title: RoboTwin 2.0
- local: robocasa
title: RoboCasa365
- local: robocerebra
title: RoboCerebra
- local: robomme
title: RoboMME
- local: vlabench
title: VLABench
title: Benchmark suites
title: Simulation & benchmarks
- sections:
- local: introduction_processors
title: Introduction to Robot Processors
- local: debug_processor_pipeline
title: Debug your processor pipeline
- local: implement_your_own_processor
title: Implement your own processor
title: Introduction to processors
- local: processors_robots_teleop
title: Processors for Robots and Teleoperators
title: Processors for robots & teleoperators
- local: env_processor
title: Environment Processors
title: Environment processors
- local: action_representations
title: Action Representations
title: "Robot Processors"
title: Action representations
- local: debug_processor_pipeline
title: Debugging a pipeline
- local: implement_your_own_processor
title: Implementing your own processor
title: Processors
- sections:
- local: so101
title: SO-101
- local: so100
title: SO-100
- local: koch
title: Koch v1.1
- local: lekiwi
title: LeKiwi
- local: hope_jr
title: Hope Jr
- local: reachy2
title: Reachy 2
- local: unitree_g1
title: Unitree G1
- local: earthrover_mini_plus
title: Earth Rover Mini
- local: omx
title: OMX
- local: openarm
title: OpenArm
title: "Robots"
- sections:
- local: phone_teleop
title: Phone
title: "Teleoperators"
- sections:
- local: so101
title: SO-101
- local: so100
title: SO-100
- local: koch
title: Koch v1.1
- local: omx
title: OMX
- local: openarm
title: OpenArm
title: Low-cost arms
- sections:
- local: lekiwi
title: LeKiwi
- local: earthrover_mini_plus
title: Earth Rover Mini
title: Mobile platforms
- sections:
- local: hope_jr
title: Hope Jr
- local: reachy2
title: Reachy 2
- local: unitree_g1
title: Unitree G1
title: Bimanual & humanoid
- sections:
- local: rebot_b601
title: reBot B601-DM
title: Research & industrial
title: Supported robots
- sections:
- local: cameras
title: Cameras
title: "Sensors"
- sections:
- local: notebooks
title: Notebooks
- local: phone_teleop
title: Phone teleoperation
- local: feetech
title: Updating Feetech Firmware
title: Feetech firmware update
- local: damiao
title: Damiao Motors and CAN Bus
title: "Resources"
title: Damiao motors & CAN bus
title: Sensors, teleop & motors
- sections:
- local: contributing
title: Contribute to LeRobot
- local: integrate_hardware
title: Bring your own hardware
- local: bring_your_own_policies
title: Add a new policy
- local: adding_benchmarks
title: Add a new benchmark
title: Extend LeRobot
- sections:
- local: troubleshooting # NEW (to create) — common errors: USB, calibration drift, CUDA OOM, video decoding…
title: Troubleshooting & FAQ
- local: glossary # NEW (to create) — episode, action chunk, leader/follower, teleop, processor…
title: Glossary
- local: notebooks
title: Example notebooks
- local: backwardcomp
title: Backward compatibility
title: "About"
title: Reference
- sections:
- local: contributing
title: Contributing to LeRobot
title: About

View File

@@ -79,17 +79,13 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
```bash
lerobot-record \
--robot.type=so100_follower \
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/act_policy \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=my_robot \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--display_data=true \
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
--dataset.num_episodes=10 \
--dataset.single_task="Your task description" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
--policy.path=${HF_USER}/act_policy
--task="Your task description" \ # can be skipped for ACT
--duration=60
```

View File

@@ -1,277 +0,0 @@
# Using Subtasks in LeRobot Datasets
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
## What are Subtasks?
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
1. "Approach the apple"
2. "Grasp the apple"
3. "Lift the apple"
4. "Move to basket"
5. "Release the apple"
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
width="80%"
/>
<p>
<em>Figure: Overview of subtask annotation.</em>
</p>
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
## Dataset Structure
Subtask information is stored in the dataset metadata:
```
my-dataset/
├── data/
│ └── ...
├── meta/
│ ├── info.json
│ ├── stats.json
│ ├── tasks.parquet
│ ├── subtasks.parquet # Subtask index → subtask string mapping
│ └── episodes/
│ └── ...
└── videos/
└── ...
```
### Subtasks Parquet File
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
| subtask_index | subtask (index column) |
| ------------- | ---------------------- |
| 0 | "Approach the apple" |
| 1 | "Grasp the apple" |
| 2 | "Lift the apple" |
| ... | ... |
### Frame-Level Annotations
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
```python
# Example frame data in the parquet file
{
"index": 42,
"timestamp": 1.4,
"episode_index": 0,
"task_index": 0,
"subtask_index": 2, # References "Lift the apple"
"observation.state": [...],
"action": [...],
}
```
## Annotating Datasets with Subtasks
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
After completing your annotation:
1. Click "Push to Hub" to upload your annotated dataset
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
## Loading Datasets with Subtasks
When you load a dataset with subtask annotations, the subtask information is automatically available:
```python
from lerobot.datasets import LeRobotDataset
# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Access a sample
sample = dataset[100]
# The sample includes both task and subtask information
print(sample["task"]) # "Collect the fruit"
print(sample["subtask"]) # "Grasp the apple"
print(sample["task_index"]) # tensor(0)
print(sample["subtask_index"]) # tensor(2)
```
### Checking for Subtask Support
You can check if a dataset has subtask annotations:
```python
# Check if subtasks are available
has_subtasks = (
"subtask_index" in dataset.features
and dataset.meta.subtasks is not None
)
if has_subtasks:
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
print("Subtasks:", list(dataset.meta.subtasks.index))
```
## Using Subtasks for Training
### With the Tokenizer Processor
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
```python
from lerobot.processor import TokenizerProcessorStep
# Create a tokenizer processor step
tokenizer_processor = TokenizerProcessorStep(
tokenizer_name_or_path="google/paligemma-3b-pt-224",
padding="max_length",
max_length=64,
)
# The processor will automatically tokenize subtasks if present in the batch
# and add them to the observation under:
# - "observation.subtask.tokens"
# - "observation.subtask.attention_mask"
```
When subtasks are available in the batch, the tokenizer processor adds:
- `observation.subtask.tokens`: Tokenized subtask text
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
### DataLoader with Subtasks
```python
import torch
from lerobot.datasets import LeRobotDataset
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=16,
shuffle=True,
)
for batch in dataloader:
# Access subtask information in the batch
subtasks = batch["subtask"] # List of subtask strings
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
# Use for training hierarchical policies or reward models
print(f"Batch subtasks: {set(subtasks)}")
```
## Example Datasets with Subtask Annotations
Try loading a dataset with subtask annotations:
```python
from lerobot.datasets import LeRobotDataset
# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Explore the subtasks
print("Available subtasks:")
for subtask_name in dataset.meta.subtasks.index:
print(f" - {subtask_name}")
# Get subtask distribution
subtask_counts = {}
for i in range(len(dataset)):
sample = dataset[i]
subtask = sample["subtask"]
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
print("\nSubtask distribution:")
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
print(f" {subtask}: {count} frames")
```
## Use Cases
### 1. Hierarchical Policy Training
Train policies that predict both actions and current subtask:
```python
class HierarchicalPolicy(nn.Module):
def __init__(self, num_subtasks):
super().__init__()
self.action_head = nn.Linear(hidden_dim, action_dim)
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
def forward(self, observations):
features = self.encoder(observations)
actions = self.action_head(features)
subtask_logits = self.subtask_head(features)
return actions, subtask_logits
```
### 2. Stage-Aware Reward Modeling (SARM)
Build reward models that understand task progression:
```python
# SARM predicts:
# - Stage: Which subtask is being executed (discrete)
# - Progress: How far along the subtask (continuous 0-1)
class SARMRewardModel(nn.Module):
def forward(self, observations):
features = self.encoder(observations)
stage_logits = self.stage_classifier(features)
progress = self.progress_regressor(features)
return stage_logits, progress
```
### 3. Progress Visualization
Monitor robot execution by tracking subtask progression:
```python
def visualize_execution(model, observations):
for t, obs in enumerate(observations):
action, subtask_logits = model(obs)
predicted_subtask = subtask_names[subtask_logits.argmax()]
print(f"t={t}: Executing '{predicted_subtask}'")
```
## API Reference
### LeRobotDataset Properties
| Property | Type | Description |
| --------------------------- | ---------------------- | ------------------------------------------ |
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
### Sample Keys
When subtasks are available, each sample includes:
| Key | Type | Description |
| --------------- | -------------- | ------------------------------------ |
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
| `subtask` | `str` | Natural language subtask description |
## Related Resources
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation

View File

@@ -105,10 +105,12 @@ These results demonstrate GR00T's strong generalization capabilities across dive
### Evaluate in your hardware setup
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
```bash
lerobot-record \
lerobot-rollout\
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
--robot.type=bi_so_follower \
--robot.left_arm_port=/dev/ttyACM1 \
--robot.right_arm_port=/dev/ttyACM0 \
@@ -119,14 +121,12 @@ lerobot-record \
}' \
--display_data=true \
--dataset.repo_id=<user>/eval_groot-bimanual \
--dataset.num_episodes=10 \
--dataset.single_task="Grab and handover the red cube to the other arm" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
--policy.path=<user>/groot-bimanual \ # your trained model
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10
--duration=600
```
## License

View File

@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem58760431541",
id="my_red_robot_arm",
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
)
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem58760431551",
id="my_blue_leader_arm",
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
)
robot = SO101Follower(robot_config)
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
<hfoption id="Command">
```bash
lerobot-teleoperate \
--robot.type=koch_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=my_awesome_follower_arm \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
--teleop.type=koch_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=my_awesome_leader_arm \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem5AB90687491 \
--robot.id=my_follower_arm \
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so101_leader \
--teleop.port=/dev/tty.usbmodem5AB90689011 \
--teleop.id=my_leader_arm \
--display_data=true
```
</hfoption>
@@ -122,34 +122,48 @@ lerobot-teleoperate \
<!-- prettier-ignore-start -->
```python
import time
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
camera_config = {
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
}
robot_config = KochFollowerConfig(
port="/dev/tty.usbmodem585A0076841",
id="my_red_robot_arm",
cameras=camera_config
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
cameras={
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
}
)
teleop_config = KochLeaderConfig(
port="/dev/tty.usbmodem58760431551",
id="my_blue_leader_arm",
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
)
robot = KochFollower(robot_config)
teleop_device = KochLeader(teleop_config)
init_rerun(session_name="teleoperation")
robot = SO101Follower(robot_config)
teleop_device = SO101Leader(teleop_config)
robot.connect()
teleop_device.connect()
TARGET_HZ = 30
TIME_PER_FRAME = 1.0 / TARGET_HZ
while True:
start_time = time.perf_counter()
observation = robot.get_observation()
action = teleop_device.get_action()
robot.send_action(action)
log_rerun_data(observation=observation, action=action)
elapsed_time = time.perf_counter() - start_time
sleep_time = TIME_PER_FRAME - elapsed_time
if sleep_time > 0:
time.sleep(sleep_time)
```
<!-- prettier-ignore-end -->
@@ -202,10 +216,11 @@ lerobot-record \
<!-- prettier-ignore-start -->
```python
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.datasets import LeRobotDataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
@@ -218,71 +233,56 @@ EPISODE_TIME_SEC = 60
RESET_TIME_SEC = 10
TASK_DESCRIPTION = "My task description"
# Create robot configuration
robot_config = SO100FollowerConfig(
id="my_awesome_follower_arm",
cameras={
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
},
port="/dev/tty.usbmodem58760434471",
)
teleop_config = SO100LeaderConfig(
id="my_awesome_leader_arm",
port="/dev/tty.usbmodem585A0077581",
)
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
teleop = SO100Leader(teleop_config)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id="<hf_username>/<dataset_repo_id>",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
teleop.connect()
# Create the required processors
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
def main():
# Create robot configuration
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
cameras={
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
}
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
)
# Initialize the robot and teleoperator
robot = SO101Follower(robot_config)
teleop = SO101Leader(teleop_config)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id="<hf_username>/<dataset_repo_id>",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
teleop.connect()
# Create the required processors
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
@@ -291,26 +291,50 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
control_time_s=RESET_TIME_SEC,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
dataset.save_episode()
episode_idx += 1
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
teleop.disconnect()
dataset.push_to_hub()
dataset.save_episode()
episode_idx += 1
# finalize dataset
log_say("Finalizing dataset...")
dataset.finalize()
# Clean up
log_say("Stop recording")
robot.disconnect()
teleop.disconnect()
dataset.push_to_hub()
if __name__ == "__main__":
main()
```
<!-- prettier-ignore-end -->
@@ -348,7 +372,7 @@ The `record` function provides a suite of tools for capturing and managing data
##### 2. Checkpointing and Resuming
- Checkpoints are automatically created during recording.
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
- To start recording from scratch, **manually delete** the dataset directory.
##### 3. Recording Parameters
@@ -422,7 +446,7 @@ from lerobot.utils.utils import log_say
episode_idx = 0
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
robot = SO100Follower(robot_config)
robot.connect()
@@ -490,6 +514,83 @@ Additionally you can provide extra `tags` or specify a `license` for your model
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
#### Train using Hugging Face Jobs
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
To run the training use this command:
<hfoptions id="train_with_hf_jobs">
<hfoption id="Command">
```bash
hf jobs run \
--flavor a10g-small \
--timeout 4h \
--secrets HF_TOKEN \
huggingface/lerobot-gpu:latest \
-- \
python -m lerobot.scripts.lerobot_train \
--dataset.repo_id=username/dataset \
--policy.type=act \
--steps=5000 \
--batch_size=16 \
--policy.device=cuda \
--policy.repo_id=username/your_policy \
--log_freq=100
```
</hfoption>
<hfoption id="API example">
<!-- prettier-ignore-start -->
```python
from huggingface_hub import run_job, get_token
run_name = "act_so101_hf_jobs"
dataset_id = "username/dataset"
user_hub_id = "username"
command_args = [
"python", "-m", "lerobot.scripts.lerobot_train",
"--dataset.repo_id", dataset_id,
"--policy.type", "act",
"--steps", "5000",
"--batch_size", "16",
"--num_workers", "4",
"--policy.device", "cuda",
"--log_freq", "100",
"--save_freq", "1000",
"--save_checkpoint", "true",
"--wandb.enable", "false",
"--policy.repo_id", f"{user_hub_id}/{run_name}"
]
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
job_info = run_job(
image="huggingface/lerobot-gpu:latest",
command=command_args,
flavor="a10g-small",
timeout="4h",
secrets={"HF_TOKEN": get_token()}
)
print("\n🚀 Job successfully launched!")
print(f"🔹 Job ID: {job_info.id}")
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
```
<!-- prettier-ignore-end -->
</hfoption>
</hfoptions>
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
For longer training sessions increase the timeout.
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
#### Upload policy checkpoints
Once training is done, upload the latest checkpoint with:

View File

@@ -0,0 +1,147 @@
# Language columns and recipes
Most LeRobot datasets ship with a single `task` string per episode — fine for
short, single-instruction skills, but not enough for the longer-horizon,
multi-modal robot policies the field is moving toward (high-level planning,
memory, interjections, VQA, tool use). To support those policies without
forking the dataset format, LeRobot extends `LeRobotDataset` with two optional
language columns and a small recipe layer that turns those rows into
chat-style training samples on the fly.
The design splits cleanly into three layers:
1. **Data in the dataset** — language annotations stored next to frames in
`data/chunk-*/file-*.parquet` as two optional columns (`language_persistent`
and `language_events`). Datasets without these columns keep their existing
behavior.
2. **Recipe** — a YAML file that declares which annotation rows to bind and
how to lay them out as chat turns (`role`, `content`, optional images,
optional tool calls). Recipes are pure config; no Python required to add a
new one.
3. **Training format** — at sample time, `RenderMessagesStep` resolves the
recipe against the per-frame annotations and emits HF-style `messages` plus
LeRobot-specific sidecars (`message_streams`, `target_message_indices`)
that policy processors consume.
This page describes each layer in turn.
## Layer 1 — language columns in the dataset
The two optional columns live next to frame data in
`data/chunk-*/file-*.parquet`:
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
Both columns share the same row shape (event rows omit `timestamp` because the
frame the row sits on already provides it):
```text
role: string
content: string | null
style: string | null
timestamp: float32 # persistent rows only
camera: string | null # observation.images.* feature key, view-dependent rows only
tool_calls: list[Json] | null
```
The `camera` field tags rows whose `content` is grounded in a specific camera
view. Rows of view-dependent styles (`vqa` and `trace`) MUST set `camera` to
the matching `observation.images.*` feature key. Rows of every other style —
including `motion`, which describes robot-frame primitives in joint / Cartesian
terms — MUST leave `camera` as `null`. Pipeline writers and the validator
enforce this via `validate_camera_field(style, camera)`.
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
### Architecture
The language stack itself has three internal modules backing layer 1:
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
2. `lerobot.datasets.language_render` resolves rows and renders messages.
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
## Layer 2 — recipe anatomy
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
declare which annotation rows to pull (via `bindings`) and how to compose them
into chat turns (`messages`).
```yaml
messages:
- { role: user, content: "${task}", stream: high_level }
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
```
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
time, exactly one branch is selected deterministically from the sample index,
so different frames train different objectives (e.g. memory updates vs.
low-level execution vs. VQA) without any Python wiring.
### Temporal semantics
Persistent styles are active after emission until replaced:
- `active_at(t, style=subtask)`
- `nth_prev(style=memory, offset=1)`
- `nth_next(style=subtask, offset=1)`
Event styles only exist on their exact timestamp:
- `emitted_at(t, style=interjection)`
- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
- `emitted_at(t, role=assistant, tool_name=say)`
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
### View-dependent resolution
For view-dependent styles (`vqa` and `trace`), the resolver gains a
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
camera at the same timestamp; without `camera=`, those resolvers see two
matches and raise an ambiguity error. Recipes consume each camera through its
own binding plus a matching image block, e.g.
```yaml
ask_vqa_top:
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- { type: image, feature: observation.images.top }
- { type: text, text: "${vqa_query}" }
- {
role: assistant,
content: "${vqa}",
stream: high_level,
target: true,
if_present: vqa,
}
```
Add one such sub-recipe per camera the dataset records.
## Layer 3 — training format
Rendered samples use HF-style chat messages plus LeRobot sidecars:
```python
sample["messages"]
sample["message_streams"]
sample["target_message_indices"]
```
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
## Graceful absence
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.

219
docs/source/quickstart.mdx Normal file
View File

@@ -0,0 +1,219 @@
# Quickstart
This is the **shortest path** from an unboxed SO-101 to a policy that drives your own robot. Every step is copy-paste; replace the **`<placeholders>`** with the values for your setup.
By the end you will have:
- A calibrated SO-101 leader + follower pair.
- A dataset of 30 episodes pushed to the Hugging Face Hub.
- A trained ACT policy (~20k steps) running on your robot via `lerobot-rollout`.
> [!NOTE]
> **How long will this take?**
> Recording 30 episodes is roughly 3060 minutes of teleoperation. Training ACT for 20k steps takes ~1.5h on an A100, a few hours on a laptop RTX 3060, longer on Apple Silicon (`mps`). The commands themselves are quick — most of the wall-clock is data collection and training.
> [!TIP]
> If you only want to **understand the codebase** or **train on an existing dataset without hardware**, this page isn't for you. Read [Core concepts](./core_concepts) first, then jump to [Imitation learning end-to-end](./il_robots).
---
## Before you start
You need:
- An **assembled SO-101 leader + follower pair**. If your robot is not assembled yet, follow the [SO-101 assembly guide](./so101) and come back here.
- **One or two cameras** (USB webcam works fine).
- A **CUDA GPU with ≥ 6 GB VRAM** (ACT is light — a laptop RTX 3060 works). Apple Silicon (`mps`) and CPU are supported but slower. See the [compute hardware guide](./hardware_guide) for sizing.
- A **Hugging Face account** — datasets and the trained policy will be pushed to your Hub.
If any of the above is missing, fix it first; the rest of the page assumes it.
---
## Step 1 — Install LeRobot
Follow the full [Installation Guide](./installation) for environment setup, then add the SO-101 motor stack and log in to the Hub:
```bash
pip install 'lerobot[feetech]'
git lfs install && git lfs pull
hf auth login # paste a token from https://huggingface.co/settings/tokens
```
Sanity check — the CLI entry points should be available:
```bash
lerobot-find-port --help
```
---
## Step 2 — Identify USB ports and motor IDs
Plug **only the follower arm** in (USB + power) and run:
```bash
lerobot-find-port
```
When prompted, unplug it and press Enter. Note the printed port — that's your `<FOLLOWER_PORT>`. Repeat with only the **leader arm** plugged in to get `<LEADER_PORT>`.
> [!TIP]
> On Linux, USB ports look like `/dev/ttyACM0`; on macOS like `/dev/tty.usbmodem...`. On Linux you may need `sudo chmod 666 /dev/ttyACM0` to grant access.
If your motors are brand-new (or repurposed), set their IDs and baudrate **once per arm**:
```bash
lerobot-setup-motors --robot.type=so101_follower --robot.port=<FOLLOWER_PORT>
lerobot-setup-motors --teleop.type=so101_leader --teleop.port=<LEADER_PORT>
```
The script walks you through connecting motors one at a time. Full details: [SO-101 → Configure the motors](./so101#configure-the-motors).
---
## Step 3 — Calibrate
Center every joint roughly in the middle of its range, then run:
```bash
lerobot-calibrate \
--robot.type=so101_follower \
--robot.port=<FOLLOWER_PORT> \
--robot.id=my_follower
lerobot-calibrate \
--teleop.type=so101_leader \
--teleop.port=<LEADER_PORT> \
--teleop.id=my_leader
```
After pressing Enter, sweep each joint through its full range of motion, then press Enter again to finish.
> [!WARNING]
> The `--robot.id` / `--teleop.id` values (`my_follower`, `my_leader`) become the **calibration keys**. Reuse the same IDs in every later command — that's how LeRobot finds the calibration on disk.
Watch the [calibration video](./so101#calibrate) if anything is unclear.
---
## Step 4 — Teleoperate (sanity check, no recording)
Before recording anything, confirm the leader drives the follower correctly:
```bash
lerobot-teleoperate \
--robot.type=so101_follower \
--robot.port=<FOLLOWER_PORT> \
--robot.id=my_follower \
--robot.cameras="{ top: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
--teleop.type=so101_leader \
--teleop.port=<LEADER_PORT> \
--teleop.id=my_leader \
--display_data=true
```
A Rerun window should open showing the camera feed and joint angles. Move the leader — the follower should mirror it in real time. If it doesn't, see [Troubleshooting & FAQ](./troubleshooting).
Don't know which camera index is which? Run `lerobot-find-cameras` — it saves a frame from each detected camera so you can pick the right one.
---
## Step 5 — Record a dataset (30 episodes)
Now record demonstrations. Pick a short, repeatable task (e.g. *"put the red brick in the bowl"*). The dataset is pushed to the Hub under your username:
```bash
export HF_USER=<your-hf-username>
lerobot-record \
--robot.type=so101_follower \
--robot.port=<FOLLOWER_PORT> \
--robot.id=my_follower \
--robot.cameras="{ top: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30} }" \
--teleop.type=so101_leader \
--teleop.port=<LEADER_PORT> \
--teleop.id=my_leader \
--dataset.repo_id=${HF_USER}/so101_quickstart \
--dataset.num_episodes=30 \
--dataset.single_task="Put the red brick in the bowl" \
--dataset.streaming_encoding=true \
--display_data=true
```
**Keyboard controls during recording:**
- **`→` (Right Arrow)** — save the current episode and move to the next.
- **`←` (Left Arrow)** — discard the current episode and retry.
- **`Esc`** — stop, encode videos, and upload to the Hub.
> [!TIP]
> **Quality beats quantity.** 30 clean, varied episodes (different brick positions, lighting, camera shake) train a much better policy than 100 identical ones. Move the object around. Vary your speed slightly.
When you're done, your dataset lives at `https://huggingface.co/datasets/${HF_USER}/so101_quickstart`. You can preview it in the browser. For deeper recording options (resume, multiple tasks, custom processors), see [Imitation learning end-to-end → Record](./il_robots#record-a-dataset).
---
## Step 6 — Train ACT
ACT (Action Chunking Transformer) is the right default for a first run — small, fast, and works well on 30 episodes.
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/so101_quickstart \
--policy.type=act \
--output_dir=outputs/train/act_so101_quickstart \
--job_name=act_so101_quickstart \
--policy.device=cuda \
--policy.repo_id=${HF_USER}/act_so101_quickstart \
--steps=20000 \
--wandb.enable=true
```
A few notes:
- Replace `--policy.device=cuda` with `mps` on Apple Silicon, or `cpu` if you have no GPU (very slow — not recommended for a real run).
- `--wandb.enable=true` is optional. If you use it, run `wandb login` first. Otherwise drop the flag.
- Checkpoints land in `outputs/train/act_so101_quickstart/checkpoints/`. The final model is also pushed to the Hub at the `--policy.repo_id` you specified.
- To resume from an interruption: `lerobot-train --config_path=outputs/train/act_so101_quickstart/checkpoints/last/pretrained_model/train_config.json --resume=true`.
> [!TIP]
> **No GPU locally?** Train on Google Colab using the [ACT notebook](./notebooks#training-act), or rent a GPU via [Hugging Face Jobs](./il_robots#train-using-hugging-face-jobs) — pay-as-you-go, no setup.
For why ACT is the default and when to switch to SmolVLA, Pi0, or another policy, see [Choosing a policy](./policies_overview).
---
## Step 7 — Run your policy on the robot
Deploy with `lerobot-rollout`. **Use the same camera layout you used while recording** — keys and resolutions must match.
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/act_so101_quickstart \
--robot.type=so101_follower \
--robot.port=<FOLLOWER_PORT> \
--robot.id=my_follower \
--robot.cameras="{ top: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30} }" \
--task="Put the red brick in the bowl" \
--duration=60
```
`--duration` is in seconds — leave it off to run until you stop the script. You should see the follower arm move on its own, attempting the task.
If observations from the robot use different keys than the policy expects, you'll need a [rename map](./rename_map). If latency matters, look at [async inference](./async) and [real-time chunking](./rtc).
---
## You're done 🎉
You now have a working IL pipeline end-to-end. From here, the natural next steps are:
- **Improve the policy** — record more diverse episodes, train longer, or try a stronger model. See [Choosing a policy](./policies_overview).
- **Go deeper on imitation learning** — [Imitation learning end-to-end](./il_robots) covers multi-camera setups, multi-task datasets, episode replay, evaluation, and Hugging Face Jobs.
- **Try RL with a human in the loop** — [HIL-SERL](./hilserl) trains a policy that improves while you correct it.
- **Use a different robot** — see [Supported robots](./so101) for low-cost arms, mobile platforms, bimanual, and humanoid.
- **Build something new** — [Bring your own hardware](./integrate_hardware) and [Add a new policy](./bring_your_own_policies).
Stuck on something? Check [Troubleshooting & FAQ](./troubleshooting), or ask on [Discord](https://discord.gg/s3KuuzsPFb).

186
docs/source/rebot_b601.mdx Normal file
View File

@@ -0,0 +1,186 @@
# reBot B601-DM
[reBot B601-DM](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/) is an open-source, low-cost robot arm from Seeed Studio for embodied-AI and imitation learning. It comes as a **follower** arm (the `B601-DM`, a 6-DOF arm plus gripper driven by Damiao CAN motors) and a **leader** arm (the `StarArm102` / `reBot Arm 102`, driven by FashionStar UART smart servos) used to teleoperate it.
This page covers **calibration** and **teleoperation** for both single-arm and bimanual (dual-arm) setups.
<div style="display: flex; align-items: center; gap: 10px;">
<img
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/b601dm_zeroposition.jpg"
alt="reBot B601-DM follower arm at its zero position"
width="48%"
/>
<img
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/102_zeroposition.jpg"
alt="reBot Arm 102 leader arm at its zero position"
width="48%"
/>
</div>
_Left: the B601-DM follower at its zero position. Right: the reBot Arm 102 leader at its zero position. Images courtesy of [Seeed Studio](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/)._
## Install LeRobot 🤗
Follow our [Installation Guide](./installation), then install the reBot support:
```bash
pip install -e ".[rebot]"
```
This pulls in `motorbridge` (CAN motor control for the B601-DM follower) and `motorbridge-smart-servo` (FashionStar UART servos for the reBot Arm 102 leader).
## Registered device types
| Type | Kind |
| ------------------------ | -------------------------------------------- |
| `rebot_b601_follower` | single-arm B601-DM follower robot |
| `bi_rebot_b601_follower` | bimanual (dual-arm) follower robot |
| `rebot_102_leader` | single-arm reBot Arm 102 leader teleoperator |
| `bi_rebot_102_leader` | bimanual (dual-arm) leader teleoperator |
The bimanual types compose two single-arm instances and namespace each arm's
observation/action keys with a `left_` / `right_` prefix. Per-arm settings are
passed through nested `left_arm_config.*` / `right_arm_config.*` arguments.
## Find the USB ports
For each device, find the USB port associated with its motor bus using:
```bash
lerobot-find-port
```
<Tip warning={true}>
On Linux, remove `brltty` (`sudo apt remove brltty`) so it does not hold the
leader's USB serial port. You may also need to grant access to the serial
devices: `sudo chmod 666 /dev/ttyACM* /dev/ttyUSB*`.
</Tip>
## Calibration
Neither arm stores a persistent hardware calibration: every time it connects, the motors are re-zeroed against the pose the arm is physically holding. Calibration simply records that zero pose. When prompted, **manually move the arm to its zero position** (the default sit-down pose shown above, gripper fully closed) and press <kbd>ENTER</kbd>.
### Follower (B601-DM)
<hfoptions id="calibrate-follower">
<hfoption id="Single arm">
```bash
lerobot-calibrate \
--robot.type=rebot_b601_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=follower \
--robot.can_adapter=damiao
```
</hfoption>
<hfoption id="Dual arm">
Connect the bimanual follower; calibration runs for the left arm, then the right arm.
```bash
lerobot-calibrate \
--robot.type=bi_rebot_b601_follower \
--robot.id=bi_follower \
--robot.left_arm_config.port=/dev/ttyACM0 \
--robot.left_arm_config.can_adapter=damiao \
--robot.right_arm_config.port=/dev/ttyACM1 \
--robot.right_arm_config.can_adapter=damiao
```
Per-arm calibration files are saved with `_left` / `_right` suffixes on the id.
</hfoption>
</hfoptions>
### Leader (reBot Arm 102)
<hfoptions id="calibrate-leader">
<hfoption id="Single arm">
```bash
lerobot-calibrate \
--teleop.type=rebot_102_leader \
--teleop.port=/dev/ttyUSB0 \
--teleop.id=leader
```
</hfoption>
<hfoption id="Dual arm">
```bash
lerobot-calibrate \
--teleop.type=bi_rebot_102_leader \
--teleop.id=bi_leader \
--teleop.left_arm_config.port=/dev/ttyUSB0 \
--teleop.right_arm_config.port=/dev/ttyUSB1
```
</hfoption>
</hfoptions>
## Teleoperation
Once both arms are calibrated, drive the follower with the leader. The follower talks to its CAN bus through a Damiao serial bridge (`can_adapter=damiao`, the default) or a SocketCAN adapter (`can_adapter=socketcan`). See the [OpenArm page](./openarm) for more details on the SocketCAN adapter configuration.
<hfoptions id="teleoperate">
<hfoption id="Single arm">
```bash
lerobot-teleoperate \
--robot.type=rebot_b601_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=follower \
--robot.can_adapter=damiao \
--teleop.type=rebot_102_leader \
--teleop.port=/dev/ttyUSB0 \
--teleop.id=leader
```
</hfoption>
<hfoption id="Dual arm">
The bimanual leader and follower reuse the single-arm classes; each arm is
configured through nested `left_arm_config.*` / `right_arm_config.*` arguments,
so a bimanual reBot Arm 102 leader drives a bimanual B601-DM follower.
```bash
lerobot-teleoperate \
--robot.type=bi_rebot_b601_follower \
--robot.id=bi_follower \
--robot.left_arm_config.port=/dev/ttyACM0 \
--robot.left_arm_config.can_adapter=damiao \
--robot.right_arm_config.port=/dev/ttyACM1 \
--robot.right_arm_config.can_adapter=damiao \
--teleop.type=bi_rebot_102_leader \
--teleop.id=bi_leader \
--teleop.left_arm_config.port=/dev/ttyUSB0 \
--teleop.right_arm_config.port=/dev/ttyUSB1
```
</hfoption>
</hfoptions>
<Tip>
The leader and follower share the same joint names (`shoulder_pan,
shoulder_lift, elbow_flex, wrist_flex, wrist_yaw, wrist_roll, gripper`), so
leader actions map directly onto the follower.
</Tip>
If the motion of a joint is reversed, flip its sign in the leader's `joint_directions` (the gripper also carries a scale to widen its range to the follower):
```bash
lerobot-teleoperate \
--robot.type=rebot_b601_follower \
--robot.port=/dev/ttyACM0 \
--robot.can_adapter=damiao \
--teleop.type=rebot_102_leader \
--teleop.port=/dev/ttyUSB0 \
--teleop.joint_directions='{"shoulder_pan":-1,"shoulder_lift":-1,"elbow_flex":1,"wrist_flex":1,"wrist_yaw":1,"wrist_roll":-1,"gripper":-6}'
```
## Recording datasets
Swap `lerobot-teleoperate` for `lerobot-record` (with the same `--robot.*` / `--teleop.*` arguments, plus `--dataset.*`) to record demonstrations for training. See [Imitation Learning for Robots](./il_robots) for the full workflow.
For hardware assembly and wiring, see the [Seeed Studio reBot wiki](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/).

View File

@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
Once you are logged in, you can run inference in your setup by doing:
```bash
lerobot-record \
lerobot-rollout \
--strategy.type=base \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \ # <- Use your port
--robot.id=my_blue_follower_arm \ # <- Use your robot id
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
--dataset.episode_time_s=50 \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
# <- RTC optional, use when running on low power hardware \
# --inference.type=rtc \
# --inference.rtc.execution_horizon=10 \
# --inference.rtc.max_guidance_weight=10.0 \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
# --teleop.id=my_red_leader_arm \
# --display_data=true #optional use if you want to see the camera stream \
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
```

210
docs/source/tools.mdx Normal file
View File

@@ -0,0 +1,210 @@
# Tools
LeRobot v3.1 supports **tool calls** in policies — assistant messages can
emit structured invocations like `say(text="OK, starting now")` that the
runtime dispatches to a real implementation (TTS, controller, logger, …).
This page covers:
1. Where the tool catalog lives.
2. How the annotation pipeline produces tool-call atoms.
3. How to add your own tool.
## Where tools are declared
Two layers.
**The catalog** — a list of OpenAI-style function schemas — lives at
`meta/info.json["tools"]` on each dataset. Example:
```json
{
"features": { "...": "..." },
"tools": [
{
"type": "function",
"function": {
"name": "say",
"description": "Speak a short utterance to the user via the TTS executor.",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The verbatim text to speak."
}
},
"required": ["text"]
}
}
}
]
}
```
Read it via the dataset metadata accessor:
```python
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
tools = meta.tools # list[dict] — OpenAI tool schemas
```
If the dataset's `info.json` doesn't declare any tools, `meta.tools`
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
single-entry list with the canonical `say` schema. So unannotated
datasets and chat-template consumers keep working without any
configuration:
```python
prompt_str = tokenizer.apply_chat_template(
sample["messages"],
tools=meta.tools, # works either way
add_generation_prompt=False,
tokenize=False,
)
```
**The implementations** — runnable Python — will live under
`src/lerobot/tools/`, one file per tool. The runtime dispatcher and
the canonical `say` implementation (wrapping Kyutai's pocket-tts) are
not part of the catalog layer described here; today this layer ships
only the schema storage and the `DEFAULT_TOOLS` fallback constant.
## Per-row tool _invocations_
The catalog above describes _what can be called_. The actual _call_ — the
function name plus the argument values — is stored per-row, on the
assistant atoms in `language_events`:
```python
{
"role": "assistant",
"content": null,
"style": null,
"timestamp": 12.4,
"camera": null,
"tool_calls": [
{ "type": "function",
"function": { "name": "say", "arguments": { "text": "On it." } } }
]
}
```
Recipes splice these into rendered messages via `tool_calls_from`:
```yaml
user_interjection_response:
bindings:
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- { role: user, content: "${task}", stream: high_level }
- {
role: assistant,
content: "${current_plan}",
stream: high_level,
target: true,
tool_calls_from: speech,
}
```
The model's training target is one assistant turn that carries both the
plan text _and_ the `say` tool call. At inference, the runtime parses
the generated text back into structured `tool_calls` and dispatches to
the matching implementation.
## How to add your own tool
> **Note:** Steps 2 and 3 below describe the runtime layer
> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`,
> `get_tools(meta)`) which is not part of the catalog layer shipped
> today — those modules don't yet exist in the tree. Step 1 alone is
> enough to make the tool visible to the chat template via
> `meta.tools` so the model can learn to _generate_ the call;
> executing the call at inference requires the runtime layer.
Three steps. Concrete example: a `record_observation` tool the policy
can call to capture an extra observation outside the regular control
loop.
### Step 1 — declare the schema
Add an entry under `meta/info.json["tools"]`. Either edit the file
directly on disk _before_ running the annotation pipeline (it'll be
preserved) or hand it to `lerobot-annotate` via a config flag.
```json
{
"tools": [
{ "type": "function", "function": { "name": "say", "...": "..." } },
{
"type": "function",
"function": {
"name": "record_observation",
"description": "Capture a high-resolution still image for the user.",
"parameters": {
"type": "object",
"properties": {
"label": {
"type": "string",
"description": "Short label for the saved image."
}
},
"required": ["label"]
}
}
}
]
}
```
The schema follows OpenAI's function-calling convention exactly, so the
chat template can render it natively.
### Step 2 — implement the call
Create `src/lerobot/tools/record_observation.py`:
```python
from .base import Tool
from typing import Any
RECORD_OBSERVATION_SCHEMA: dict[str, Any] = { "...": "..." } # mirrors the JSON above
class RecordObservationTool:
name = "record_observation"
schema = RECORD_OBSERVATION_SCHEMA
def __init__(self, schema: dict | None = None, output_dir: str = "."):
self.output_dir = output_dir
def call(self, arguments: dict) -> str:
label = arguments["label"]
# ... save the latest camera frame to <output_dir>/<label>.png ...
return f"saved {label}.png"
```
One file per tool keeps dependencies isolated — `record_observation`
might pull `pillow`, while `say` pulls `pocket-tts`. Users installing
only the tools they need avoid heavy transitive deps.
### Step 3 — register it
Add to `src/lerobot/tools/registry.py`:
```python
from .record_observation import RecordObservationTool
TOOL_REGISTRY["record_observation"] = RecordObservationTool
```
That's it. At runtime `get_tools(meta)` looks up each schema in
`meta.tools`, instantiates the matching registered class, and returns
a name → instance dict the dispatcher can route into.
If you want to use a tool _without_ writing an implementation (e.g. for
training-time chat-template formatting only), step 1 alone is enough —
the model still learns to _generate_ the call. Steps 2 and 3 are only
needed to actually _execute_ it at inference.

View File

@@ -82,7 +82,7 @@ After the first episode of a video stream is encoded, the encoder configuration
"video.pix_fmt": "yuv420p",
"video.fps": 30,
"video.channels": 3,
"is_depth_map": false,
"video.is_depth_map": false,
"video.g": 2,
"video.crf": 30,
"video.preset": "fast",
@@ -97,7 +97,7 @@ After the first episode of a video stream is encoded, the encoder configuration
Two sources contribute to the `info` block:
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `is_depth_map`, plus `audio.*` if an audio stream is present.
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
<Tip>

View File

@@ -15,10 +15,12 @@
# limitations under the License.
"""
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
Downloads datasets from HuggingFace, seeks directly into the episode segment
of the source video, draws a progress line on each frame, and writes the result.
The progress data is read from a parquet file that lives alongside the dataset
(configurable via ``--progress-file``).
Usage:
python examples/dataset/create_progress_videos.py \
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
TASK_FONT_SCALE = 0.55
def download_episode_metadata(repo_id: str, episode: int) -> Path:
"""Download only the metadata and sarm_progress files for a dataset.
def download_episode_metadata(
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
) -> Path:
"""Download only the metadata and per-frame progress file for a dataset.
Args:
repo_id: HuggingFace dataset repository ID.
episode: Episode index (used for logging only; all meta is fetched).
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns:
Local cache path for the downloaded snapshot.
"""
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
local_path = Path(
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
allow_patterns=["meta/**", "sarm_progress.parquet"],
allow_patterns=["meta/**", progress_file],
ignore_patterns=["*.mp4"],
)
)
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
return video_path
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
"""Load sarm_progress values for an episode.
def load_progress_data(
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
) -> np.ndarray | None:
"""Load per-frame progress values for an episode.
Args:
local_path: Dataset cache root.
episode: Episode index.
progress_file: Filename of the per-frame progress parquet.
Returns:
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
"""
parquet_path = local_path / "sarm_progress.parquet"
parquet_path = local_path / progress_file
if not parquet_path.exists():
logging.warning("sarm_progress.parquet not found")
logging.warning("%s not found", progress_file)
return None
df = pd.read_parquet(parquet_path)
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
logging.info(" %s columns: %s", progress_file, list(df.columns))
episode_df = df[df["episode_index"] == episode].copy()
if episode_df.empty:
logging.warning("No sarm_progress rows for episode %d", episode)
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
return None
episode_df = episode_df.sort_values("frame_index")
@@ -576,6 +585,7 @@ def process_dataset(
camera_key: str | None,
output_dir: Path,
create_gif: bool = False,
progress_file: str = "sarm_progress.parquet",
) -> Path | None:
"""Full pipeline: download, extract metadata, composite progress, write output.
@@ -585,6 +595,8 @@ def process_dataset(
camera_key: Camera key to use, or None for auto-selection.
output_dir: Directory to write output files.
create_gif: If True, also generate a GIF from the MP4.
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns:
Path to the final output file, or None on failure.
@@ -592,7 +604,7 @@ def process_dataset(
safe_name = repo_id.replace("/", "_")
logging.info("Processing: %s | episode %d", repo_id, episode)
local_path = download_episode_metadata(repo_id, episode)
local_path = download_episode_metadata(repo_id, episode, progress_file)
logging.info(" Local cache: %s", local_path)
episode_meta = load_episode_meta(local_path, episode, camera_key)
@@ -600,9 +612,9 @@ def process_dataset(
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
progress_data = load_progress_data(local_path, episode)
progress_data = load_progress_data(local_path, episode, progress_file)
if progress_data is None:
logging.error("Could not load sarm_progress data. Skipping overlay.")
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
return None
logging.info(" Progress frames: %d", len(progress_data))
@@ -627,7 +639,7 @@ def process_dataset(
def main() -> None:
parser = argparse.ArgumentParser(
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
)
parser.add_argument(
"--repo-id",
@@ -658,6 +670,15 @@ def main() -> None:
action="store_true",
help="Also generate a GIF from the MP4 output.",
)
parser.add_argument(
"--progress-file",
type=str,
default="sarm_progress.parquet",
help=(
"Filename of the per-frame progress parquet inside the dataset repo "
"(default: 'sarm_progress.parquet')."
),
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
@@ -670,6 +691,7 @@ def main() -> None:
camera_key=args.camera_key,
output_dir=args.output_dir,
create_gif=args.gif,
progress_file=args.progress_file,
)
if result:

View File

@@ -95,7 +95,7 @@ dependencies = [
# ── Feature-scoped extras ──────────────────────────────────
dataset = [
"datasets>=4.0.0,<5.0.0",
"datasets>=4.7.0,<5.0.0",
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]",
@@ -138,7 +138,9 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
# Common
av-dep = ["av>=15.0.0,<16.0.0"]
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.9.17"]
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
placo-dep = ["placo>=0.9.6,<0.9.16"]
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
@@ -151,6 +153,8 @@ pyserial-dep = ["pyserial>=3.5,<4.0"]
deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"]
pynput-dep = ["pynput>=1.7.8,<1.9.0"]
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
motorbridge-dep = ["motorbridge>=0.3.2,<0.4.0"]
motorbridge-smart-servo-dep = ["motorbridge-smart-servo>=0.0.4,<0.1.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
@@ -174,6 +178,9 @@ unitree_g1 = [
"lerobot[pygame-dep]",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
# leader (motorbridge-smart-servo / FashionStar UART servos).
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
kinematics = ["lerobot[placo-dep]"]
intelrealsense = [
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
@@ -260,6 +267,7 @@ all = [
"lerobot[lekiwi]",
"lerobot[openarms]",
"lerobot[reachy2]",
"lerobot[rebot]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
"lerobot[diffusion]",

View File

@@ -199,12 +199,13 @@ class OpenCVCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
"""
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
if self.config.fourcc is not None:
self._validate_fourcc()
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
set_fourcc_after_size_and_fps = platform.system() == "Windows"
if self.config.fourcc is not None and not set_fourcc_after_size_and_fps:
self._validate_fourcc()
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
@@ -222,6 +223,11 @@ class OpenCVCamera(Camera):
else:
self._validate_fps()
if self.config.fourcc is not None and set_fourcc_after_size_and_fps:
# On Windows with DSHOW, changing the resolution can silently override the FOURCC setting.
# Set FOURCC last to make sure the requested pixel format is actually enforced.
self._validate_fourcc()
def _validate_fps(self) -> None:
"""Validates and sets the camera's frames per second (FPS)."""
@@ -430,7 +436,7 @@ class OpenCVCamera(Camera):
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color frame (blocking call)
1. Reads a color frame
2. Stores result in latest_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners
@@ -439,9 +445,8 @@ class OpenCVCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
stop_event = self.stop_event
failure_count = 0
while not stop_event.is_set():
while not self.stop_event.is_set():
try:
raw_frame = self._read_from_hardware()
processed_frame = self._postprocess_image(raw_frame)
@@ -479,8 +484,6 @@ class OpenCVCamera(Camera):
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
if self.thread.is_alive():
logger.warning(f"{self} read thread did not terminate within timeout.")
self.thread = None
self.stop_event = None

View File

@@ -332,8 +332,8 @@ class RealSenseCamera(Camera):
from the camera hardware via the RealSense pipeline.
Returns:
np.ndarray: The depth map as a NumPy array (height, width, 1)
of type `np.uint16` (raw depth values in millimeters).
np.ndarray: The depth map as a NumPy array (height, width)
of type `np.uint16` (raw depth values in millimeters) and rotation.
Raises:
DeviceNotConnectedError: If the camera is not connected.
@@ -465,8 +465,8 @@ class RealSenseCamera(Camera):
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color/depth frame (blocking call with 10s timeout)
2. Stores result in latest_color_frame/latest_depth_frame and updates timestamp (thread-safe)
1. Reads a color frame with 500ms timeout
2. Stores result in latest_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
@@ -474,9 +474,8 @@ class RealSenseCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
stop_event = self.stop_event
failure_count = 0
while not stop_event.is_set():
while not self.stop_event.is_set():
try:
frame = self._read_from_hardware()
color_frame_raw = frame.get_color_frame()
@@ -487,8 +486,6 @@ class RealSenseCamera(Camera):
depth_frame_raw = frame.get_depth_frame()
depth_frame = np.asanyarray(depth_frame_raw.get_data())
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
if processed_depth_frame.ndim == 2: # (H, W) -> (H, W, 1)
processed_depth_frame = processed_depth_frame[..., np.newaxis]
capture_time = time.perf_counter()
@@ -525,8 +522,6 @@ class RealSenseCamera(Camera):
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
if self.thread.is_alive(): # pragma: no cover
logger.warning(f"{self} read thread did not terminate within timeout.")
self.thread = None
self.stop_event = None
@@ -537,6 +532,7 @@ class RealSenseCamera(Camera):
self.latest_timestamp = None
self.new_frame_event.clear()
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
@@ -579,6 +575,7 @@ class RealSenseCamera(Camera):
return frame
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking).
@@ -614,71 +611,6 @@ class RealSenseCamera(Camera):
return frame
@check_if_not_connected
def async_read_depth(self, timeout_ms: float = 200) -> NDArray[Any]:
"""Read the latest depth frame asynchronously, in metric meters.
Mirrors :meth:`async_read` but returns the depth stream rather than the
color stream. Output is ``np.uint16`` of shape ``(H, W, 1)``.
Raises:
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
the background read thread is not running.
TimeoutError: If no frame becomes available within ``timeout_ms``.
"""
if not self.use_depth:
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
raise TimeoutError(f"Timed out waiting for depth frame from camera {self} after {timeout_ms} ms.")
with self.frame_lock:
depth_frame = self.latest_depth_frame
self.new_frame_event.clear()
if depth_frame is None:
raise RuntimeError(f"Internal error: Event set but no depth frame available for {self}.")
return depth_frame
@check_if_not_connected
def read_latest_depth(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent depth frame in metric meters (peeking).
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
Output is ``np.uint16`` of shape ``(H, W, 1)`` in millimeters.
Raises:
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
no depth frame has been captured yet.
TimeoutError: If the latest depth frame is older than ``max_age_ms``.
"""
if not self.use_depth:
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
depth_frame = self.latest_depth_frame
timestamp = self.latest_timestamp
if depth_frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any depth frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest depth frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return depth_frame
def disconnect(self) -> None:
"""
Disconnects from the camera, stops the pipeline, and cleans up resources.

View File

@@ -249,9 +249,8 @@ class ZMQCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized.")
stop_event = self.stop_event
failure_count = 0
while not stop_event.is_set():
while not self.stop_event.is_set():
try:
frame = self._read_from_hardware()
capture_time = time.perf_counter()
@@ -293,8 +292,6 @@ class ZMQCamera(Camera):
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
if self.thread.is_alive():
logger.warning(f"{self} read thread did not terminate within timeout.")
self.thread = None
self.stop_event = None

View File

@@ -24,6 +24,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
from .dataset import DatasetRecordConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .recipe import MessageTurn, TrainingRecipe, load_recipe
from .types import (
FeatureType,
NormalizationMode,
@@ -34,10 +35,8 @@ from .types import (
from .video import (
VALID_VIDEO_CODECS,
VIDEO_ENCODER_INFO_KEYS,
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
)
__all__ = [
@@ -51,14 +50,15 @@ __all__ = [
"DatasetRecordConfig",
"DatasetConfig",
"EvalConfig",
"MessageTurn",
"PeftConfig",
"PreTrainedConfig",
"TrainingRecipe",
"WandBConfig",
"load_recipe",
"VideoEncoderConfig",
"DepthEncoderConfig",
# Defaults
"camera_encoder_defaults",
"depth_encoder_defaults",
# Constants
"VALID_VIDEO_CODECS",
"VIDEO_ENCODER_INFO_KEYS",

View File

@@ -18,7 +18,7 @@ from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from .video import DepthEncoderConfig, VideoEncoderConfig, camera_encoder_defaults, depth_encoder_defaults
from .video import VideoEncoderConfig, camera_encoder_defaults
@dataclass
@@ -60,8 +60,6 @@ class DatasetRecordConfig:
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
# Video encoder settings for depth-map MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys.
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
# Enable streaming video encoding: encode frames in real-time during capture instead
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
streaming_encoding: bool = False

View File

@@ -0,0 +1,206 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, get_args
MessageRole = Literal["user", "assistant", "system", "tool"]
MessageStream = Literal["high_level", "low_level"]
DEFAULT_BINDINGS = {
"subtask": "active_at(t, style=subtask)",
"memory": "active_at(t, style=memory)",
"plan": "active_at(t, style=plan)",
"speech": "emitted_at(t, role=assistant, tool_name=say)",
"interjection": "emitted_at(t, style=interjection)",
"vqa": "emitted_at(t, style=vqa, role=assistant)",
"vqa_query": "emitted_at(t, style=vqa, role=user)",
}
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
"""``${name}`` placeholder pattern used by both recipe binding-reference
discovery (here) and rendered-message substitution (in ``language_render``)."""
_VALID_ROLES = frozenset(get_args(MessageRole))
_VALID_STREAMS = frozenset(get_args(MessageStream))
@dataclass
class MessageTurn:
"""A single chat-style turn in a recipe template.
``content`` may be a plain string, a list of HF-style multimodal blocks, or
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
training target, and ``if_present`` skips the turn when the named binding
resolves to ``None``.
"""
role: MessageRole
content: str | list[dict[str, Any]] | None = None
stream: MessageStream | None = None
target: bool = False
if_present: str | None = None
tool_calls_from: str | None = None
def __post_init__(self) -> None:
"""Validate role, stream, and content after dataclass construction."""
if self.role not in _VALID_ROLES:
raise ValueError(f"Unsupported message role: {self.role!r}")
# ``stream`` is typed Optional only so the dataclass can keep its
# field ordering, but recipes must always tag every turn with a
# stream — the renderer's ``_validate_rendered`` would reject
# ``None`` later on. Fail at construction so the bad recipe is
# caught at YAML load time rather than at the first sample.
if self.stream is None:
raise ValueError(
f"MessageTurn(role={self.role!r}) is missing a stream — "
f"every turn must declare one of {sorted(_VALID_STREAMS)}."
)
if self.stream not in _VALID_STREAMS:
raise ValueError(f"Unsupported message stream: {self.stream!r}")
if self.content is None and self.tool_calls_from is None:
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
if self.content is not None and not isinstance(self.content, (str, list)):
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
if isinstance(self.content, list):
for block in self.content:
if not isinstance(block, dict) or "type" not in block:
raise ValueError(
"Multimodal content blocks must be HF-style dictionaries with a type key."
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
"""Construct a :class:`MessageTurn` from a plain dictionary."""
return cls(**data)
@dataclass
class TrainingRecipe:
"""A recipe describing how to render training samples from language rows.
A recipe is either a *message recipe* (``messages`` plus optional
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
sub-recipes). ``weight`` is only meaningful inside a blend.
"""
messages: list[MessageTurn] | None = None
bindings: dict[str, str] | None = None
blend: dict[str, TrainingRecipe] | None = None
weight: float | None = None
def __post_init__(self) -> None:
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
if self.messages is not None and self.blend is not None:
raise ValueError("TrainingRecipe must set only one of messages or blend.")
if self.messages is None and self.blend is None:
raise ValueError("TrainingRecipe must set one of messages or blend.")
if self.messages is not None:
self._validate_message_recipe()
if self.blend is not None:
self._validate_blend_recipe()
@classmethod
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
data = dict(data)
if data.get("messages") is not None:
data["messages"] = [
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
for turn in data["messages"]
]
if data.get("blend") is not None:
data["blend"] = {
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
for name, recipe in data["blend"].items()
}
return cls(**data)
@classmethod
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
import yaml # type: ignore[import-untyped]
with open(path) as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
return cls.from_dict(data)
def _validate_message_recipe(self) -> None:
"""Ensure every templated binding is known and at least one turn is a target."""
assert self.messages is not None
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
for turn in self.messages:
missing = self._referenced_bindings(turn) - known_bindings
if missing:
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
if not any(turn.target for turn in self.messages):
raise ValueError("Message recipes must contain at least one target turn.")
def _validate_blend_recipe(self) -> None:
"""Ensure each blend component is a non-empty, weighted message recipe."""
assert self.blend is not None
if not self.blend:
raise ValueError("Blend recipes must contain at least one component.")
for name, recipe in self.blend.items():
if recipe.blend is not None:
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
if recipe.messages is None:
raise ValueError(f"Blend component {name!r} must define messages.")
if recipe.weight is None:
raise ValueError(f"Blend component {name!r} must define weight.")
if recipe.weight <= 0:
raise ValueError(f"Blend component {name!r} must have a positive weight.")
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
"""Return the binding names that ``turn`` references via placeholders or attributes."""
names: set[str] = set()
if turn.if_present is not None:
names.add(turn.if_present)
if turn.tool_calls_from is not None:
names.add(turn.tool_calls_from)
names.update(_placeholders_in_content(turn.content))
return names
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
if content is None:
return set()
if isinstance(content, str):
return set(PLACEHOLDER_RE.findall(content))
names: set[str] = set()
for block in content:
for value in block.values():
if isinstance(value, str):
names.update(PLACEHOLDER_RE.findall(value))
return names
def load_recipe(path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
return TrainingRecipe.from_yaml(path)

View File

@@ -19,8 +19,8 @@
from __future__ import annotations
import logging
from dataclasses import dataclass, field, fields
from typing import Any, ClassVar
from dataclasses import dataclass, field
from typing import Any
from lerobot.utils.import_utils import require_package
@@ -36,12 +36,11 @@ HW_VIDEO_CODECS = [
"h264_vaapi", # Linux Intel/AMD
"h264_qsv", # Intel Quick Sync
]
VALID_VIDEO_CODECS: frozenset[str] = frozenset(
{"h264", "hevc", "libsvtav1", "ffv1", "auto", *HW_VIDEO_CODECS}
)
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
# Aliases for legacy video codec names.
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
LIBSVTAV1_DEFAULT_PRESET: int = 12
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
@@ -53,19 +52,6 @@ VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
)
# Default depth quantization and encoding parameters.
DEPTH_QUANT_BITS: int = 12
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
DEFAULT_DEPTH_MIN: float = 0.01
DEFAULT_DEPTH_MAX: float = 10.0
DEFAULT_DEPTH_SHIFT: float = 3.5
DEFAULT_DEPTH_USE_LOG: bool = True
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
@dataclass
class VideoEncoderConfig:
@@ -100,10 +86,6 @@ class VideoEncoderConfig:
video_backend: str = "pyav"
extra_options: dict[str, Any] = field(default_factory=dict)
# Source-data channel count this encoder is expected to handle (3 for RGB,
# 1 for depth, etc.)
_DEFAULT_CHANNELS: ClassVar[int] = 3
def __post_init__(self) -> None:
self.resolve_vcodec()
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
@@ -156,9 +138,7 @@ class VideoEncoderConfig:
require_package("av", extra="dataset")
from lerobot.datasets import check_video_encoder_parameters_pyav
check_video_encoder_parameters_pyav(
self.vcodec, self.pix_fmt, self.get_codec_options(), channels=self._DEFAULT_CHANNELS
)
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options())
def resolve_vcodec(self) -> None:
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
@@ -238,10 +218,6 @@ class VideoEncoderConfig:
elif self.vcodec == "h264_qsv":
set_if("global_quality", self.crf)
set_if("preset", self.preset)
elif self.vcodec == "ffv1":
# Lossless intra-frame codec. ``crf``/``preset``/``fast_decode``
# are not meaningful.
set_if("threads", encoder_threads)
else:
set_if("crf", self.crf)
set_if("preset", self.preset)
@@ -257,59 +233,3 @@ class VideoEncoderConfig:
def camera_encoder_defaults() -> VideoEncoderConfig:
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
return VideoEncoderConfig()
@dataclass
class DepthEncoderConfig(VideoEncoderConfig):
"""Encoder configuration for depth-map streams.
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
preset, ``extra_options``…) and adds the four parameters of the depth
quantizer.
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
to ``"gray12le"``.
Attributes:
depth_min: Minimum depth in physical units (e.g. metres) represented
by quantum ``0``.
depth_max: Maximum depth represented by quantum :data:`DEPTH_QMAX`.
shift: Pre-log offset for numerical stability near zero.
use_log: ``True`` for logarithmic quantization (default; matches
sensor error profile), ``False`` for linear.
"""
vcodec: str = "hevc"
pix_fmt: str = "gray12le"
depth_min: float = DEFAULT_DEPTH_MIN
depth_max: float = DEFAULT_DEPTH_MAX
shift: float = DEFAULT_DEPTH_SHIFT
use_log: bool = DEFAULT_DEPTH_USE_LOG
_DEFAULT_CHANNELS: ClassVar[int] = 1
@classmethod
def from_video_info(cls, video_info: dict | None) -> DepthEncoderConfig:
"""Reconstruct a :class:`DepthEncoderConfig` from a depth feature's ``info`` block.
Reuses :meth:`VideoEncoderConfig.from_video_info` for the base
codec/tuning fields and then layers the depth-specific tuning
(``depth_min`` / ``depth_max`` / ``shift`` / ``use_log``) on top.
Missing keys fall back to the class defaults.
"""
base = VideoEncoderConfig.from_video_info(video_info)
kwargs: dict[str, Any] = {f.name: getattr(base, f.name) for f in fields(base) if f.init}
video_info = video_info or {}
for name in DEPTH_ENCODER_INFO_FIELD_NAMES:
value = video_info.get(f"video.{name}")
if value is not None:
kwargs[name] = value
return cls(**kwargs)
def depth_encoder_defaults() -> DepthEncoderConfig:
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
return DepthEncoderConfig()

View File

@@ -31,12 +31,21 @@ from .dataset_tools import (
modify_features,
modify_tasks,
recompute_stats,
reencode_dataset,
remove_feature,
split_dataset,
)
from .factory import make_dataset, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats
from .language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
column_for_style,
)
from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
@@ -54,10 +63,15 @@ __all__ = [
"CODEBASE_VERSION",
"DEFAULT_EPISODES_PATH",
"DEFAULT_QUANTILES",
"EVENT_ONLY_STYLES",
"EpisodeAwareSampler",
"LANGUAGE_EVENTS",
"LANGUAGE_PERSISTENT",
"LeRobotDataset",
"LeRobotDatasetMetadata",
"MultiLeRobotDataset",
"PERSISTENT_STYLES",
"STYLE_REGISTRY",
"StreamingLeRobotDataset",
"VideoEncodingManager",
"check_video_encoder_parameters_pyav",
@@ -69,6 +83,7 @@ __all__ = [
"convert_image_to_video_dataset",
"create_initial_features",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
"get_feature_stats",
"load_episodes",
@@ -77,6 +92,7 @@ __all__ = [
"modify_features",
"modify_tasks",
"recompute_stats",
"reencode_dataset",
"remove_feature",
"resolve_delta_timestamps",
"safe_stop_image_writer",

View File

@@ -512,7 +512,7 @@ def compute_episode_stats(
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
if features[key]["dtype"] in {"string", "language"}:
continue
if features[key]["dtype"] in ["image", "video"]:
@@ -550,10 +550,8 @@ def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
if key == "count" and value.shape != (1,):
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
if "image" in feature_key and key != "count" and value.shape not in ((3, 1, 1), (1, 1, 1)):
raise ValueError(
f"Shape of quantile '{key}' must be (3,1,1) or (1,1,1) but is {value.shape} instead."
)
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):

View File

@@ -36,12 +36,12 @@ from .io_utils import (
load_episodes,
load_info,
load_stats,
load_subtasks,
load_tasks,
write_info,
write_stats,
write_tasks,
)
from .language import DEFAULT_TOOLS, LANGUAGE_COLUMNS
from .utils import (
DEFAULT_EPISODES_PATH,
check_version_compatibility,
@@ -177,7 +177,6 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -338,30 +337,54 @@ class LeRobotDatasetMetadata:
"""Keys to access visual modalities stored as videos."""
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
@property
def depth_keys(self) -> list[str]:
"""Keys to access depth-map modalities stored as videos or images.
A depth key is a feature whose ``info`` dict carries ``"is_depth_map": True``
(or the legacy ``"video.is_depth_map"`` inside ``info`` or ``video_info``).
"""
def _is_depth(ft: dict) -> bool:
info = ft.get("info") or {}
video_info = ft.get("video_info") or {}
return (
info.get("is_depth_map", False)
or info.get("video.is_depth_map", False)
or video_info.get("video.is_depth_map", False)
)
return [key for key, ft in self.features.items() if _is_depth(ft)]
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def has_language_columns(self) -> bool:
"""Return ``True`` if the dataset declares any language column.
Used to gate language-aware code paths (collate, render step) so
unannotated datasets keep PyTorch's default collate behavior.
"""
return any(col in self.features for col in LANGUAGE_COLUMNS)
@property
def tools(self) -> list[dict]:
"""OpenAI-style tool schemas declared by this dataset.
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
can mutate the result safely. Falls back to
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
``say`` schema) when the dataset doesn't declare any — that way
unannotated datasets and chat-template consumers
(``apply_chat_template(messages, tools=meta.tools)``) keep
working out of the box.
Implementations live under :mod:`lerobot.tools` (one file per
tool); see ``docs/source/tools.mdx`` for the authoring guide.
"""
declared = self.info.tools
if declared:
return [dict(t) for t in declared]
return [dict(t) for t in DEFAULT_TOOLS]
@tools.setter
def tools(self, value: list[dict] | None) -> None:
"""Persist a tool catalog to ``meta/info.json`` and reload metadata.
Writes ``value`` into the on-disk ``info.json`` (or clears the
``tools`` key when ``value`` is ``None`` or empty), then reloads
``self.info`` so the in-memory metadata matches what's on disk.
Saves callers from hand-editing ``info.json`` and re-instantiating
the metadata object.
"""
self.info.tools = [dict(t) for t in value] if value else None
write_info(self.info, self.root)
self.info = load_info(self.root)
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
@@ -557,7 +580,7 @@ class LeRobotDatasetMetadata:
def update_video_info(
self,
video_key: str | None = None,
video_encoder: VideoEncoderConfig | None = None,
camera_encoder: VideoEncoderConfig | None = None,
) -> None:
"""Populate per-feature video info in ``info.json``.
@@ -577,13 +600,9 @@ class LeRobotDatasetMetadata:
video_keys = [video_key] if video_key is not None else self.video_keys
for key in video_keys:
existing = self.features[key].get("info") or {}
# Skip only if real video info has already been written. The ``is_depth_map`` entry (created at feature creation) is not blocking.
if set(existing.keys()) - {"is_depth_map"}:
continue
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
new_info = get_video_info(video_path, video_encoder=video_encoder)
self.info.features[key]["info"] = {**existing, **new_info}
if not self.features[key].get("info", None):
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info.features[key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
def update_chunk_settings(
self,
@@ -694,7 +713,6 @@ class LeRobotDatasetMetadata:
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(

View File

@@ -22,10 +22,7 @@ from pathlib import Path
import datasets
import torch
from lerobot.configs.video import DepthEncoderConfig
from .dataset_metadata import LeRobotDatasetMetadata
from .depth_utils import dequantize_depth
from .feature_utils import (
check_delta_timestamps,
get_delta_indices,
@@ -89,12 +86,6 @@ class DatasetReader:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
##TODO(CarolinePascal): Should we rather use a more lightweight structure ?
self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
vid_key: DepthEncoderConfig.from_video_info(self._meta.features[vid_key].get("info"))
for vid_key in self._meta.depth_keys
}
def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient."""
try:
@@ -256,18 +247,7 @@ class DatasetReader:
self._tolerance_s,
self._video_backend,
return_uint8=self._return_uint8,
is_depth=vid_key in self._meta.depth_keys,
)
if vid_key in self._meta.depth_keys:
depth_encoder = self._depth_encoder_configs[vid_key]
frames = dequantize_depth(
frames,
depth_min=depth_encoder.depth_min,
depth_max=depth_encoder.depth_max,
shift=depth_encoder.shift,
use_log=depth_encoder.use_log,
output_tensor=True,
)
return vid_key, frames.squeeze(0)
items = list(query_timestamps.items())
@@ -315,9 +295,4 @@ class DatasetReader:
task_idx = item["task_index"].item()
item["task"] = self._meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
return item

View File

@@ -26,7 +26,7 @@ This module provides utilities for:
import logging
import shutil
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from pathlib import Path
import datasets
@@ -61,11 +61,13 @@ from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
VIDEO_DIR,
update_chunk_file_indices,
)
from .video_utils import (
encode_video_frames,
get_video_info,
reencode_video,
)
@@ -1329,7 +1331,7 @@ def _estimate_frame_size_via_calibration(
imgs_dir=calibration_dir,
video_path=calibration_video_path,
fps=fps,
video_encoder=camera_encoder,
camera_encoder=camera_encoder,
overwrite=True,
)
@@ -1813,7 +1815,7 @@ def convert_image_to_video_dataset(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
video_encoder=camera_encoder,
camera_encoder=camera_encoder,
overwrite=True,
)
@@ -1860,7 +1862,7 @@ def convert_image_to_video_dataset(
video_key=img_key, chunk_index=0, file_index=0
)
new_meta.info.features[img_key]["info"] = get_video_info(
video_path, video_encoder=camera_encoder
video_path, camera_encoder=camera_encoder
)
write_info(new_meta.info, new_meta.root)
@@ -1884,3 +1886,83 @@ def convert_image_to_video_dataset(
# Return new dataset
return LeRobotDataset(repo_id=repo_id, root=output_dir)
def _reencode_video_worker(args: tuple) -> Path:
"""Picklable worker for :func:`reencode_dataset`'s process pool."""
video_path, camera_encoder, encoder_threads = args
reencode_video(
input_video_path=video_path,
output_video_path=video_path,
camera_encoder=camera_encoder,
encoder_threads=encoder_threads,
overwrite=True,
)
return video_path
def reencode_dataset(
dataset: LeRobotDataset,
camera_encoder: VideoEncoderConfig,
encoder_threads: int | None = None,
num_workers: int | None = None,
) -> LeRobotDataset:
"""Re-encode every video in a dataset with a new set of encoding parameters.
Videos are re-encoded in-place and the video information in ``info.json`` is refreshed.
Args:
dataset: An existing :class:`LeRobotDataset` whose videos will be
re-encoded.
camera_encoder: Target encoder configuration applied to every video
file.
encoder_threads: Per-encoder thread count forwarded to
:func:`reencode_video`. ``None`` lets the codec decide.
num_workers: Number of parallel processes. ``None`` or ``0`` means
sequential (no multiprocessing); ``1+`` spawns a
:class:`~concurrent.futures.ProcessPoolExecutor`.
Returns:
The same :class:`LeRobotDataset` instance with its metadata updated
on disk.
"""
meta = dataset.meta
video_paths_list = []
# Only re-encode if the videos are not already encoded with the given video encoding parameters
for video_key in meta.video_keys:
current_info = meta.info.features[video_key].get("info", {})
current_encoder = VideoEncoderConfig.from_video_info(current_info)
if current_encoder != camera_encoder:
video_paths_list.extend((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
else:
logging.info(f"{video_key} videos are already encoded with {camera_encoder}. Nothing to do.")
if len(video_paths_list) == 0:
logging.warning("Dataset has no videos to re-encode.")
return dataset
logging.info(f"Re-encoding {len(video_paths_list)} video file(s) with {camera_encoder}")
worker_args = [(vp, camera_encoder, encoder_threads) for vp in video_paths_list]
if num_workers and num_workers > 1:
with ProcessPoolExecutor(max_workers=num_workers) as pool:
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
for future in tqdm(
as_completed(futures),
total=len(futures),
desc="Re-encoding videos",
):
future.result()
else:
for args in tqdm(worker_args, desc="Re-encoding videos"):
_reencode_video_worker(args)
# Refresh video info in metadata for every video key.
for vid_key in meta.video_keys:
video_path = meta.root / meta.get_video_file_path(0, vid_key)
meta.info.features[vid_key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
write_info(meta.info, meta.root)
logging.info("Dataset metadata updated.")
return dataset

View File

@@ -31,12 +31,7 @@ import PIL.Image
import pyarrow.parquet as pq
import torch
from lerobot.configs import (
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
)
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
from .compute_stats import compute_episode_stats
from .dataset_metadata import LeRobotDatasetMetadata
@@ -53,7 +48,6 @@ from .io_utils import (
write_info,
)
from .utils import (
DEFAULT_DEPTH_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_IMAGE_PATH,
update_chunk_file_indices,
@@ -73,22 +67,17 @@ def _encode_video_worker(
episode_index: int,
root: Path,
fps: int,
video_encoder: VideoEncoderConfig | None = None,
camera_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
path_template = (
DEFAULT_DEPTH_PATH
if video_encoder is not None and isinstance(video_encoder, DepthEncoderConfig)
else DEFAULT_IMAGE_PATH
)
fpath = path_template.format(image_key=video_key, episode_index=episode_index, frame_index=0)
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent
encode_video_frames(
img_dir,
temp_path,
fps,
video_encoder=video_encoder,
camera_encoder=camera_encoder,
encoder_threads=encoder_threads,
overwrite=True,
)
@@ -108,7 +97,6 @@ class DatasetWriter:
meta: LeRobotDatasetMetadata,
root: Path,
camera_encoder: VideoEncoderConfig | None,
depth_encoder: DepthEncoderConfig | None,
encoder_threads: int | None,
batch_encoding_size: int,
streaming_encoder: StreamingVideoEncoder | None = None,
@@ -122,8 +110,6 @@ class DatasetWriter:
root: Local dataset root directory.
camera_encoder: Video encoder settings applied to all cameras.
``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`.
depth_encoder: Video encoder settings applied to all **depth** cameras.
``None`` uses :func:`~lerobot.configs.depth_encoder_defaults`.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
batch_encoding_size: Number of episodes to accumulate before
@@ -135,7 +121,6 @@ class DatasetWriter:
self._meta = meta
self._root = root
self._camera_encoder = camera_encoder or camera_encoder_defaults()
self._depth_encoder = depth_encoder or depth_encoder_defaults()
self._encoder_threads = encoder_threads
self._batch_encoding_size = batch_encoding_size
self._streaming_encoder = streaming_encoder
@@ -160,8 +145,7 @@ class DatasetWriter:
return ep_buffer
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
path_template = DEFAULT_DEPTH_PATH if image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
fpath = path_template.format(
fpath = DEFAULT_IMAGE_PATH.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return self._root / fpath
@@ -211,7 +195,6 @@ class DatasetWriter:
if frame_index == 0 and self._streaming_encoder is not None:
self._streaming_encoder.start_episode(
video_keys=list(self._meta.video_keys),
depth_video_keys=set(self._meta.video_keys) & set(self._meta.depth_keys),
temp_dir=self._root,
)
@@ -267,7 +250,14 @@ class DatasetWriter:
for key, ft in self._meta.features.items():
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
episode_buffer[key] = np.stack(episode_buffer[key])
stacked_values = np.stack(episode_buffer[key])
# `shape=(1,)` numeric features are serialized as `datasets.Value`, which expects scalars.
# Normalizing to `(N,)` keeps save semantics stable across dependency versions.
if tuple(ft["shape"]) == (1,) and ft["dtype"] != "string":
stacked_values = stacked_values.reshape(episode_length)
episode_buffer[key] = stacked_values
# Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer()
@@ -310,9 +300,7 @@ class DatasetWriter:
episode_index,
self._root,
self._meta.fps,
self._depth_encoder
if video_key in self._meta.depth_keys
else self._camera_encoder,
self._camera_encoder,
self._encoder_threads,
): video_key
for video_key in self._meta.video_keys
@@ -523,12 +511,7 @@ class DatasetWriter:
# Update video info (only needed when first episode is encoded)
if episode_index == 0:
self._meta.update_video_info(
video_key,
video_encoder=self._depth_encoder
if video_key in self._meta.depth_keys
else self._camera_encoder,
)
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
write_info(self._meta.info, self._meta.root)
metadata = {
@@ -595,14 +578,13 @@ class DatasetWriter:
self.image_writer.wait_until_done()
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""Use ffmpeg to convert frames stored as png/tiff into mp4 videos."""
is_depth = video_key in self._meta.depth_keys
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
return _encode_video_worker(
video_key,
episode_index,
self._root,
self._meta.fps,
self._depth_encoder if is_depth else self._camera_encoder,
self._camera_encoder,
self._encoder_threads,
)

View File

@@ -1,214 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
"""
import math
from typing import Literal
import av
import numpy as np
import torch
from numpy.typing import NDArray
from lerobot.configs.video import (
DEFAULT_DEPTH_MAX,
DEFAULT_DEPTH_MIN,
DEFAULT_DEPTH_PIX_FMT,
DEFAULT_DEPTH_SHIFT,
DEFAULT_DEPTH_USE_LOG,
DEPTH_QMAX,
)
from .pyav_utils import write_u16_plane
_MM_PER_METRE = 1000.0
_UINT16_MAX = 65535
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
"""Ensure ``log(depth_min + shift)`` is finite."""
if depth_min + shift <= 0:
raise ValueError(
f"depth_min + shift must be positive for logarithmic quantization, "
f"got depth_min={depth_min} + shift={shift} = {depth_min + shift}"
)
def _depth_input_to_float32_and_unit(
depth: NDArray[np.integer] | NDArray[np.floating],
input_unit: Literal["auto", "m", "mm"],
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
"""Convert depth to float32 in the chosen unit, and return the resolved unit."""
resolved_unit = (
("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit
)
return depth.astype(np.float32, order="K"), resolved_unit
def quantize_depth(
depth: NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
video_backend: str | None = "pyav",
input_unit: Literal["auto", "m", "mm"] = "auto",
) -> NDArray[np.uint16] | av.VideoFrame:
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
Depth maps are packed into 12-bit integer frames so they fit in standard
high-bit-depth pixel formats (e.g. ``yuv420p12le`` / ``gray12le``)
and can be encoded by widely supported video codecs (HEVC Main 12, ffv1).
Logarithmic quantization is the default because it allocates more quanta
to near-range depth, which matches the (1/depth) error profile of typical
depth sensors. Math is ported from BEHAVIOR-1K's ``obs_utils.py``.
**Input units**:
- ``input_unit="auto"`` (default): infer from dtype (floating = m, non-floating = mm).
- ``input_unit="mm"``: interpret input values as millimetres.
- ``input_unit="m"``: interpret input values as metres.
Quantization math runs in the **resolved input unit**.
``depth_min``, ``depth_max``, and ``shift`` are always in **metres**.
Args:
depth: Depth map; ``torch.Tensor`` is moved to CPU for conversion.
depth_min: Depth (metres) at quantum ``0``.
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
use_log: If ``True`` (default), quantize in log space.
video_backend: Video backend to use for encoding. Defaults to "pyav".
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
Returns:
``numpy.ndarray``, ``dtype=uint16``, same shape as ``depth``, values in
``[0, DEPTH_QMAX]``.
Raises:
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
"""
if input_unit not in ("auto", "m", "mm"):
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
if isinstance(depth, torch.Tensor):
depth = depth.detach().cpu().numpy()
# Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W)
if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1):
depth = depth.squeeze()
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
# Convert depth_min, depth_max, and shift to the resolved input unit.
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
# Normalization and quantization is performed in the resolved input unit.
if use_log:
_validate_log_quant_params(depth_min, shift)
log_min = math.log(float(depth_min_u + shift_u))
log_max = math.log(float(depth_max_u + shift_u))
norm = (np.log(depth_f + shift_u) - log_min) / (log_max - log_min)
else:
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
if video_backend == "pyav":
frame = av.VideoFrame.from_ndarray(quantized, format=pix_fmt)
write_u16_plane(frame.planes[0], quantized)
return frame
else:
return quantized
def dequantize_depth(
quantized: NDArray[np.uint16] | av.VideoFrame,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
output_unit: Literal["m", "mm"] = "mm",
output_tensor: bool = False,
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
"""Inverse of :func:`quantize_depth`.
Tuning arguments **must match** :func:`quantize_depth`.
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
the requested output unit.
Args:
quantized: 12-bit codes ``[0, DEPTH_QMAX]``, ``dtype=uint16``.
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
output_unit: ``\"mm\"`` returns ``uint16`` millimetres (``rint``, clip
``[0, 65535]``). ``\"m\"`` returns ``float32`` metres in
``[depth_min, depth_max]``.
output_tensor: If True, return a torch.Tensor instead of a numpy array.
Returns:
Depth map in the requested unit and dtype.
Raises:
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
ValueError: If ``output_unit`` is not ``\"m\"`` or ``\"mm\"``.
"""
if output_unit not in ("m", "mm"):
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
if isinstance(quantized, av.VideoFrame):
quantized = quantized.to_ndarray(format=pix_fmt)
norm = np.asarray(quantized, dtype=np.float32, order="K") / DEPTH_QMAX
depth_min_m = np.float32(depth_min)
depth_max_m = np.float32(depth_max)
shift_m = np.float32(shift)
# The de-normalization and de-quantization is performed in meters (convenience choice).
if use_log:
_validate_log_quant_params(depth_min, shift)
log_min = math.log(float(depth_min_m + shift_m))
log_max = math.log(float(depth_max_m + shift_m))
depth_m = np.exp(norm * (log_max - log_min) + log_min) - shift_m
else:
depth_m = norm * (depth_max_m - depth_min_m) + depth_min_m
depth_m = np.clip(depth_m, depth_min_m, depth_max_m).astype(np.float32, copy=False)
# Add single-channel dim: (H, W) → (H, W, 1)
if depth_m.ndim == 2:
depth_m = depth_m[..., np.newaxis]
# Return depth as float32 meters.
if output_unit == "m":
return torch.from_numpy(depth_m) if output_tensor else depth_m
# Return depth as uint16 millimeters.
mm = np.rint(depth_m * _MM_PER_METRE).clip(0, _UINT16_MAX).astype(np.uint16, copy=False)
if output_tensor:
# torch.uint16 support is very limited, we convert to float32 instead.
return torch.from_numpy(mm.astype(np.float32))
else:
return mm

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pprint import pformat
import datasets
@@ -23,6 +24,12 @@ from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
from lerobot.utils.constants import DEFAULT_FEATURES
from lerobot.utils.utils import is_valid_numpy_dtype_string
from .language import (
LANGUAGE_PERSISTENT,
is_language_column,
language_events_column_feature,
language_persistent_column_feature,
)
from .utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -47,7 +54,13 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
"""
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video":
if is_language_column(key):
hf_features[key] = (
language_persistent_column_feature()
if key == LANGUAGE_PERSISTENT
else language_events_column_feature()
)
elif ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
@@ -278,6 +291,8 @@ def validate_feature_dtype_and_shape(
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
elif expected_dtype == "language":
return validate_feature_language(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
@@ -321,7 +336,7 @@ def validate_feature_image_or_video(
Args:
name (str): The name of the feature.
expected_shape (list[str]): The expected shape, e.g. (C, H, W) or (H, W, C).
expected_shape (list[str]): The expected shape (C, H, W).
value: The image data to validate.
Returns:
@@ -357,6 +372,30 @@ def validate_feature_string(name: str, value: str) -> str:
return ""
def validate_feature_language(name: str, value) -> str:
"""Validate a feature that is expected to hold language annotations.
Language columns (``language_persistent`` / ``language_events``) are
populated after recording by the annotation pipeline, not at record time.
Any value supplied here is dropped before the frame is written, so a
non-empty value almost certainly signals a mistake. We warn rather than
fail to keep recording resilient.
Args:
name (str): The name of the feature.
value: The value to validate.
Returns:
str: Always an empty string — language values are non-fatal.
"""
if value is not None:
logging.warning(
f"The feature '{name}' is a 'language' column populated by the annotation pipeline, "
f"not at record time. The provided value will be dropped."
)
return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
"""Validate the episode buffer before it's written to disk.

View File

@@ -42,41 +42,10 @@ def safe_stop_image_writer(func):
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim != 3:
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
Behaviour by shape:
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
The native dtype is preserved using the matching PIL mode
(``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
(existing behaviour, gated by ``range_check``).
Other shapes / channel counts raise ``NotImplementedError`` or
``ValueError``.
"""
# TODO(CarolinePascal): 4 dimensions RGB-D images
if image_array.ndim not in (2, 3):
raise ValueError(f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image.")
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
# caller emits (H, W), (1, H, W), or (H, W, 1).
if image_array.ndim == 3:
if image_array.shape[0] == 1:
image_array = image_array[0]
elif image_array.shape[-1] == 1:
image_array = image_array[..., 0]
if image_array.ndim == 2:
if image_array.dtype not in [np.uint16, np.float32]:
raise ValueError(
f"Unsupported single-channel image dtype: {image_array.dtype}. "
f"Supported dtypes: {sorted(str(d) for d in [np.uint16, np.float32])}."
)
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
# 3D path: must be RGB (3 channels), channels-first or channels-last.
if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
image_array = image_array.transpose(1, 2, 0)
@@ -102,28 +71,13 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
return PIL.Image.fromarray(image_array)
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
PNG uses ``compress_level`` (0-9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
"""
suffix = Path(fpath).suffix.lower()
if suffix == ".png":
return {"compress_level": compress_level}
if suffix in (".tif", ".tiff"):
return {"compression": "raw"}
return {}
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
"""
Saves a NumPy array or PIL Image to a file.
This function handles both NumPy arrays and PIL Image objects, converting
the former to a PIL Image before saving. It includes error handling for
the save operation. The output format is inferred from the *fpath*
extension: ``.png`` → PNG with ``compress_level``, ``.tiff`` / ``.tif``
→ lossless raw depth maps (TIFF).
the save operation.
Args:
image (np.ndarray | PIL.Image.Image): The image data to save.
@@ -147,7 +101,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
img = image
else:
raise TypeError(f"Unsupported image type: {type(image)}")
img.save(fpath, **save_kwargs_for_path(fpath, compress_level))
img.save(fpath, compress_level=compress_level)
except Exception as e:
logger.error("Error writing image %s: %s", fpath, e)

View File

@@ -31,10 +31,10 @@ from torchvision import transforms
from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
from .language import LANGUAGE_COLUMNS
from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_EPISODES_PATH,
DEFAULT_SUBTASKS_PATH,
DEFAULT_TASKS_PATH,
EPISODES_DIR,
INFO_PATH,
@@ -186,14 +186,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
@@ -265,11 +257,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
if key in LANGUAGE_COLUMNS:
continue
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None:
elif first_item is None or isinstance(first_item, dict):
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
@@ -304,8 +298,9 @@ def item_to_torch(item: dict) -> dict:
Returns:
dict: Dictionary with all tensor-like items converted to torch.Tensor.
"""
skip_keys = {"task", *LANGUAGE_COLUMNS}
for key, val in item.items():
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item

View File

@@ -0,0 +1,242 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Literal
import datasets
import pyarrow as pa
LANGUAGE_PERSISTENT = "language_persistent"
LANGUAGE_EVENTS = "language_events"
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
CORE_STYLES = {
"subtask",
"plan",
"memory",
"motion",
"interjection",
"vqa",
"trace",
"task_aug",
}
# Project-local styles can be registered at import time by appending to
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
# validation. Empty by default — populate from a downstream module that
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
# the new style's column.
EXTENDED_STYLES: set[str] = set()
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
# feature key. Rows of every other style MUST have ``camera=None``. ``motion``
# is intentionally NOT in this set: motion primitives are described in
# robot-frame (joint / Cartesian) terms, not pixel space, so they are
# camera-agnostic. ``trace`` is the pixel-trajectory event style and IS
# view-dependent. The ``camera`` field nevertheless lives on
# ``PERSISTENT_ROW_FIELDS`` too so the schema, validator, and resolver
# behave symmetrically across the two columns; persistent rows simply
# always have ``camera=None`` in practice today.
VIEW_DEPENDENT_STYLES = {"vqa", "trace"}
LanguageColumn = Literal["language_persistent", "language_events"]
def _json_arrow_type() -> pa.DataType:
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
return pa.json_() if hasattr(pa, "json_") else pa.string()
def _json_feature() -> object:
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
def language_persistent_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single persistent language row.
Persistent rows carry their own ``timestamp`` because they represent a state
that became active at a specific moment and remains active until superseded.
``timestamp`` is ``float32`` to match the timestamp dtype LeRobotDataset
uses for frame data.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("timestamp", pa.float32(), nullable=False),
pa.field("camera", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_event_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single event language row.
Event rows have no ``timestamp`` field: each event is stored on the dataset
row whose frame timestamp is the event's firing time.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("camera", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_persistent_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_persistent`` column."""
return pa.list_(language_persistent_row_arrow_type())
def language_events_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_events`` column."""
return pa.list_(language_event_row_arrow_type())
def language_persistent_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"timestamp": datasets.Value("float32"),
"camera": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_event_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for an event language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"camera": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_persistent_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
return datasets.List(language_persistent_row_feature())
def language_events_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
return datasets.List(language_event_row_feature())
def language_feature_info() -> dict[str, dict]:
"""Return the ``info["features"]`` entries for both language columns."""
return {
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
}
def is_language_column(key: str) -> bool:
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
return key in LANGUAGE_COLUMNS
def is_view_dependent_style(style: str | None) -> bool:
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
return style in VIEW_DEPENDENT_STYLES
def validate_camera_field(style: str | None, camera: str | None) -> None:
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
a non-view-dependent style carries one. Pipeline writers and the validator
should call this on every emitted row.
"""
if is_view_dependent_style(style):
if not camera:
raise ValueError(
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
f"field referencing an 'observation.images.*' feature key."
)
elif camera is not None:
raise ValueError(f"Rows of style {style!r} must have camera=None; got camera={camera!r}.")
# --- Tool registry --------------------------------------------------------
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
# of OpenAI-style function schemas. The runtime / training stack reads them
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
# fallback when the dataset doesn't declare any). Implementations live
# under :mod:`lerobot.tools` (one file per tool); see
# ``docs/source/tools.mdx`` for the authoring guide.
SAY_TOOL_SCHEMA: dict = {
"type": "function",
"function": {
"name": "say",
"description": "Speak a short utterance to the user via the TTS executor.",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The verbatim text to speak.",
}
},
"required": ["text"],
},
},
}
"""Canonical schema for the ``say`` tool emitted by the steerable
annotation pipeline (PR 2 Module 2). Single source of truth — PR 2's
writer, PR 3's runtime tool registry, and the dataset visualizer all
import this constant rather than duplicating the dict."""
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
chat-template consumers (``apply_chat_template(messages, tools=...)``)
keep working out of the box."""
def column_for_style(style: str | None) -> LanguageColumn:
"""Map a language style to the column where rows of that style are stored.
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
to :data:`LANGUAGE_EVENTS`.
"""
if style is None:
return LANGUAGE_EVENTS
if style in PERSISTENT_STYLES:
return LANGUAGE_PERSISTENT
if style in EVENT_ONLY_STYLES:
return LANGUAGE_EVENTS
raise ValueError(f"Unknown language style: {style!r}")

View File

@@ -0,0 +1,545 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import hashlib
import re
from collections.abc import Sequence
from typing import Any
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
from lerobot.utils.utils import unwrap_scalar
from .language import LANGUAGE_PERSISTENT, column_for_style
LanguageRow = dict[str, Any]
RenderedMessages = dict[str, list[Any]]
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
def active_at(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row of ``style`` that is active at time ``t``.
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
``camera`` selector. Only valid for persistent styles.
"""
_validate_persistent_resolver("active_at", style)
matches = [
row
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
if _timestamp(row) <= t
]
if not matches:
return None
latest_ts = max(_timestamp(row) for row in matches)
return _select_one(
[row for row in matches if _timestamp(row) == latest_ts],
style=style,
role=role,
tool_name=tool_name,
camera=camera,
)
EMITTED_AT_TOLERANCE_S = 0.1
"""Half-window for matching persistent rows to a frame timestamp in
``emitted_at``. Persistent timestamps come from parquet (float32) and ``t``
is also a float32 from parquet, so in the ideal hot path an exact match
would suffice — but any caller that derives ``t`` arithmetically (e.g.
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
common arithmetic drift without admitting frames that are visibly far
apart at typical control rates (30100 Hz). This does mean two persistent
rows of the same selector emitted within 0.1 s of each other cannot be
told apart by ``emitted_at`` — acceptable because persistent annotations
(subtask / plan / memory transitions) change on a human-action timescale,
not at the camera frame rate."""
def emitted_at(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the row of ``style`` emitted at exactly time ``t``.
For persistent styles, this matches persistent rows whose own ``timestamp``
is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why
we use a tolerance instead of bit-equality). For event styles, the
``events`` list is assumed to come from the dataset row at frame ``t``
(event rows carry no timestamp of their own), so all matching event rows
are considered emitted at ``t``. ``camera`` filters by the row's
``camera`` field — required to disambiguate when multiple view-dependent
rows share ``(t, role)`` across cameras.
"""
if column_for_style(style) == LANGUAGE_PERSISTENT:
matches = [
row
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S
]
else:
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera)
def nth_prev(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that was active ``offset`` steps before ``t``.
Walks back through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
one ``offset`` positions before the row active at ``t``. Only valid for
persistent styles.
"""
return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera)
def nth_next(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
Walks forward through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
one ``offset`` positions after the row active at ``t``. Only valid for
persistent styles.
"""
return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera)
def render_sample(
*,
recipe: TrainingRecipe,
persistent: Sequence[LanguageRow] | None,
events: Sequence[LanguageRow] | None,
t: float,
sample_idx: int,
task: str | None = None,
dataset_ctx: Any | None = None,
) -> RenderedMessages | None:
"""Render the chat-style messages for a single dataset sample.
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
at frame timestamp ``t``, then expands the recipe's message templates.
Returns ``None`` if the resolved sample contains no target message.
"""
persistent_rows = _normalize_rows(persistent or [])
event_rows = _normalize_rows(events or [])
selected_recipe = _select_recipe(recipe, sample_idx)
bindings = _resolve_bindings(
selected_recipe,
persistent=persistent_rows,
events=event_rows,
t=t,
sample_idx=sample_idx,
task=task,
dataset_ctx=dataset_ctx,
)
return _render_message_recipe(selected_recipe, bindings)
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
if recipe.blend is None:
return recipe
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
if total_weight <= 0:
raise ValueError("Blend weights must sum to a positive value.")
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
cumulative = 0.0
last_component: TrainingRecipe | None = None
for component in recipe.blend.values():
last_component = component
cumulative += component.weight or 0.0
if draw < cumulative:
return component
assert last_component is not None
return last_component
def _resolve_bindings(
recipe: TrainingRecipe,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
sample_idx: int,
task: str | None,
dataset_ctx: Any | None,
) -> dict[str, LanguageRow | str | None]:
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
bindings: dict[str, LanguageRow | str | None] = {
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
}
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
for name, spec in specs.items():
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
return bindings
def _resolve_task(
task: str | None,
dataset_ctx: Any | None,
*,
persistent: Sequence[LanguageRow] = (),
sample_idx: int = 0,
) -> str | None:
"""Return the task string for ``sample_idx``.
Resolution order:
1. Explicit ``task`` override (caller-supplied) wins.
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
deterministically pick one by ``sample_idx`` so each frame of an
episode rotates through the available rephrasings across an epoch.
This realizes Xiao 2022 / CAST-style task-prompt diversity without
changing ``meta/tasks.parquet`` and without forcing recipes to opt
in: ``${task}`` automatically picks a rephrasing when one exists,
and falls back to the canonical task otherwise. Recipes that want
the literal canonical task can override the binding.
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
backed by ``meta/tasks.parquet``).
"""
if task is not None:
return task
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
if aug_rows:
# Deterministic, blake2b-based pick keyed on sample_idx so the
# rotation is reproducible across runs (Python's built-in ``hash``
# is process-randomized).
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
idx = int.from_bytes(digest, "big") % len(aug_rows)
chosen = aug_rows[idx].get("content")
if chosen:
return str(chosen)
if dataset_ctx is None:
return None
if isinstance(dataset_ctx, dict):
return dataset_ctx.get("task")
return getattr(dataset_ctx, "task", None)
def _resolve_spec(
spec: str,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
) -> LanguageRow | None:
"""Parse a single binding's resolver expression and dispatch to its function."""
match = _RESOLVER_RE.match(spec.strip())
if match is None:
raise ValueError(f"Invalid resolver expression: {spec!r}")
name = match.group("name")
kwargs = _parse_resolver_args(match.group("args"))
kwargs.pop("t_arg", None)
if name == "emitted_at":
return emitted_at(t, persistent=persistent, events=events, **kwargs)
if name == "active_at":
return active_at(t, persistent=persistent, **kwargs)
if name == "nth_prev":
return nth_prev(t, persistent=persistent, **kwargs)
if name == "nth_next":
return nth_next(t, persistent=persistent, **kwargs)
raise ValueError(f"Unknown language resolver: {name!r}")
def _parse_resolver_args(args: str) -> dict[str, Any]:
"""Parse a comma-separated resolver argument list into a kwargs dict."""
kwargs: dict[str, Any] = {}
if not args.strip():
return kwargs
parts = [part.strip() for part in args.split(",") if part.strip()]
for part in parts:
if part == "t":
kwargs["t_arg"] = True
continue
if "=" not in part:
raise ValueError(f"Invalid resolver argument: {part!r}")
key, value = (item.strip() for item in part.split("=", 1))
if key == "offset":
kwargs[key] = int(value)
else:
kwargs[key] = value.strip("\"'")
return kwargs
def _render_message_recipe(
recipe: TrainingRecipe,
bindings: dict[str, LanguageRow | str | None],
) -> RenderedMessages | None:
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
assert recipe.messages is not None
messages: list[dict[str, Any]] = []
streams: list[str | None] = []
target_indices: list[int] = []
for turn in recipe.messages:
if turn.if_present is not None and bindings.get(turn.if_present) is None:
continue
message = {"role": turn.role}
if turn.content is not None:
message["content"] = _render_content(turn.content, bindings)
if turn.tool_calls_from is not None:
row = bindings.get(turn.tool_calls_from)
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
if tool_calls:
message["tool_calls"] = copy.deepcopy(tool_calls)
message_idx = len(messages)
messages.append(message)
streams.append(turn.stream)
if turn.target:
target_indices.append(message_idx)
if not target_indices:
return None
rendered = {
"messages": messages,
"message_streams": streams,
"target_message_indices": target_indices,
}
_validate_rendered(rendered)
return rendered
def _render_content(
content: str | list[dict[str, Any]],
bindings: dict[str, LanguageRow | str | None],
) -> str | list[dict[str, Any]]:
"""Substitute bindings into a string or each string field of multimodal blocks."""
if isinstance(content, str):
return _substitute(content, bindings)
rendered_blocks = []
for block in content:
rendered_block = copy.deepcopy(block)
for key, value in rendered_block.items():
if isinstance(value, str):
rendered_block[key] = _substitute(value, bindings)
rendered_blocks.append(rendered_block)
return rendered_blocks
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
def replace(match: re.Match[str]) -> str:
"""Resolve a single ``${name}`` match to its bound string value."""
name = match.group(1)
if name not in bindings:
raise ValueError(f"Unknown template binding: {name!r}")
value = bindings[name]
if value is None:
return ""
if isinstance(value, dict):
content = value.get("content")
return "" if content is None else str(content)
return str(value)
return PLACEHOLDER_RE.sub(replace, template)
def _validate_rendered(rendered: RenderedMessages) -> None:
"""Sanity-check the rendered output for stream/target alignment."""
messages = rendered["messages"]
streams = rendered["message_streams"]
target_indices = rendered["target_message_indices"]
if len(streams) != len(messages):
raise ValueError("message_streams must be aligned with messages.")
if not target_indices:
raise ValueError("Rendered samples must contain at least one target message.")
for idx in target_indices:
if idx < 0 or idx >= len(messages):
raise ValueError(f"Target message index {idx} is out of bounds.")
# ``stream`` is enforced non-None at MessageTurn construction time
# (see ``MessageTurn.__post_init__``), so a missing stream here would
# mean the dataclass invariant was bypassed; no need to re-check.
def _nth_relative(
name: str,
t: float,
persistent: Sequence[LanguageRow],
style: str | None,
offset: int,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> LanguageRow | None:
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
_validate_persistent_resolver(name, style)
if abs(offset) < 1:
raise ValueError(f"{name} offset must be non-zero.")
rows = sorted(
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
key=_row_sort_key,
)
if not rows:
return None
anchor_idx = None
for idx, row in enumerate(rows):
if _timestamp(row) <= t:
anchor_idx = idx
else:
break
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
return None
return rows[target_idx]
def _validate_persistent_resolver(name: str, style: str | None) -> None:
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
if style is None:
raise ValueError(f"{name} requires a persistent style.")
if column_for_style(style) != LANGUAGE_PERSISTENT:
raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
def _matching_rows(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> list[LanguageRow]:
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
return [
row
for row in rows
if (style is None or row.get("style") == style)
and (role is None or row.get("role") == role)
and (tool_name is None or _row_has_tool_name(row, tool_name))
and (camera is None or row.get("camera") == camera)
]
def _select_one(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> LanguageRow | None:
"""Return the single matching row, or raise if the resolver is ambiguous.
Multiple matches always raise — even when the caller already passed
some selectors — because remaining ambiguity means the data has
several rows that look identical to the resolver and the caller
needs to pin down a specific one (e.g. add ``camera=...`` for VQA
rows shared across cameras).
"""
if not rows:
return None
if len(rows) > 1:
raise ValueError(
f"Ambiguous resolver for style={style!r} role={role!r} "
f"tool_name={tool_name!r} camera={camera!r}: {len(rows)} matching rows. "
f"Add a selector that distinguishes them."
)
return rows[0]
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
"""Stable sort key for both persistent and event rows.
Event rows lack ``timestamp`` (it is implicit in the frame), so default
to ``0.0`` — within a single frame all event rows share the same sort
bucket and are tiebroken by ``(style, role)``.
"""
timestamp = row.get("timestamp")
ts = float(unwrap_scalar(timestamp)) if timestamp is not None else 0.0
return (ts, row.get("style") or "", row.get("role") or "")
def _timestamp(row: LanguageRow) -> float:
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
return float(unwrap_scalar(row["timestamp"]))
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
for tool_call in row.get("tool_calls") or []:
if isinstance(tool_call, str):
continue
function = tool_call.get("function") if isinstance(tool_call, dict) else None
if isinstance(function, dict) and function.get("name") == tool_name:
return True
return False
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
normalized = []
for row in rows:
if row is None:
continue
if hasattr(row, "as_py"):
row = row.as_py()
if not isinstance(row, dict):
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
normalized.append(dict(row))
return normalized

View File

@@ -24,7 +24,7 @@ import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
from lerobot.configs import VideoEncoderConfig
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
@@ -60,7 +60,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return_uint8: bool = False,
batch_encoding_size: int = 1,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
encoder_threads: int | None = None,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
@@ -187,9 +186,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults`
is used by the writer.
depth_encoder (DepthEncoderConfig | None, optional): Video encoder settings for depth cameras
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.depth.depth_encoder_defaults`
is used by the writer.
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
codec decide.
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
@@ -277,7 +273,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = self._build_streaming_encoder(
self.meta.fps,
camera_encoder,
depth_encoder,
encoder_queue_maxsize,
encoder_threads,
)
@@ -285,7 +280,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
meta=self.meta,
root=self.root,
camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
@@ -328,14 +322,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _build_streaming_encoder(
fps: int,
camera_encoder: VideoEncoderConfig | None,
depth_encoder: DepthEncoderConfig | None,
encoder_queue_maxsize: int,
encoder_threads: int | None,
) -> StreamingVideoEncoder:
return StreamingVideoEncoder(
fps=fps,
camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
queue_maxsize=encoder_queue_maxsize,
encoder_threads=encoder_threads,
)
@@ -653,7 +645,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
batch_encoding_size: int = 1,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
metadata_buffer_size: int = 10,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
@@ -686,8 +677,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
batch-encoding videos. ``1`` means encode immediately.
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.depth.depth_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
metadata_buffer_size: Number of episode metadata records to buffer
@@ -731,13 +720,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder(
fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
fps, camera_encoder, encoder_queue_maxsize, encoder_threads
)
obj.writer = DatasetWriter(
meta=obj.meta,
root=obj.root,
camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
@@ -761,7 +749,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
batch_encoding_size: int = 1,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
encoder_threads: int | None = None,
image_writer_processes: int = 0,
image_writer_threads: int = 0,
@@ -791,8 +778,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
batch-encoding videos.
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.depth.depth_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
image_writer_processes: Subprocesses for async image writing.
@@ -839,13 +824,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder(
obj.meta.fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
obj.meta.fps, camera_encoder, encoder_queue_maxsize, encoder_threads
)
obj.writer = DatasetWriter(
meta=obj.meta,
root=obj.root,
camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,

View File

@@ -24,7 +24,6 @@ import logging
from typing import Any
import av
import numpy as np
logger = logging.getLogger(__name__)
@@ -32,22 +31,6 @@ FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
height, width = src.shape
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
if fill_value is not None:
dst.fill(fill_value)
dst[:, :width] = src
@functools.cache
def get_pix_fmt_channels(pix_fmt: str) -> int:
"""Return the number of components (channels) for *pix_fmt*."""
return len(av.VideoFormat(pix_fmt).components)
@functools.cache
def get_codec(vcodec: str) -> av.codec.Codec | None:
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
@@ -159,16 +142,6 @@ def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
)
def _check_pix_fmt_channels(pix_fmt: str, channels: int) -> None:
"""Ensure *pix_fmt* can carry at least *channels* components."""
pix_fmt_channels = get_pix_fmt_channels(pix_fmt)
if pix_fmt_channels < channels:
raise ValueError(
f"pix_fmt={pix_fmt!r} carries only {pix_fmt_channels} component(s) "
f"but the source data has {channels} channel(s)."
)
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
supported_options = _get_codec_options_by_name(vcodec)
@@ -183,18 +156,12 @@ def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
_check_option_value(vcodec, key, value, supported_options[key])
def check_video_encoder_parameters_pyav(
vcodec: str,
pix_fmt: str,
codec_options: dict[str, Any],
channels: int | None = None,
) -> None:
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
"""Verify *config* is compatible with the bundled FFmpeg build.
Checks pixel format, abstract tuning-field compatibility, and each merged
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
against PyAV (including numeric ``extra_options`` present in that dict).
When given, additionally verify that *pix_fmt* carries as many components as the source data channels.
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
Raises:
@@ -204,6 +171,4 @@ def check_video_encoder_parameters_pyav(
if not options:
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
_check_pixel_format(vcodec, pix_fmt)
if channels is not None:
_check_pix_fmt_channels(pix_fmt, channels)
_check_codec_options(vcodec, codec_options)

View File

@@ -88,12 +88,10 @@ VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.tiff"
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
@@ -131,6 +129,9 @@ class DatasetInfo:
# Optional metadata
robot_type: str | None = None
splits: dict[str, str] = field(default_factory=dict)
# OpenAI-style tool schemas declared by the dataset. ``None`` means the
# dataset doesn't declare any — readers fall back to ``DEFAULT_TOOLS``.
tools: list[dict] | None = None
def __post_init__(self) -> None:
# Coerce feature shapes from list to tuple — JSON deserialisation
@@ -152,11 +153,15 @@ class DatasetInfo:
"""Return a JSON-serialisable dict.
Converts tuple shapes back to lists so ``json.dump`` can handle them.
Drops ``tools`` when unset so existing datasets keep a clean
``info.json``.
"""
d = dataclasses.asdict(self)
for ft in d["features"].values():
if isinstance(ft.get("shape"), tuple):
ft["shape"] = list(ft["shape"])
if d.get("tools") is None:
d.pop("tools", None)
return d
@classmethod

View File

@@ -17,11 +17,13 @@ import contextlib
import glob
import importlib
import logging
import os
import queue
import shutil
import tempfile
import threading
import warnings
from collections import OrderedDict
from dataclasses import asdict, dataclass, field
from fractions import Fraction
from pathlib import Path
@@ -37,16 +39,11 @@ from datasets.features.features import register_feature
from PIL import Image
from lerobot.configs import (
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
)
from lerobot.utils.import_utils import get_safe_default_video_backend
from .depth_utils import quantize_depth
from .pyav_utils import get_pix_fmt_channels
logger = logging.getLogger(__name__)
@@ -56,7 +53,6 @@ def decode_video_frames(
tolerance_s: float,
backend: str | None = None,
return_uint8: bool = False,
is_depth: bool = False,
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
@@ -76,11 +72,6 @@ def decode_video_frames(
Currently supports torchcodec on cpu and pyav.
"""
if backend != "pyav" and is_depth:
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
# We do not actually return uint8 here, but we avoid the 255 normalization step.
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=True, is_depth=True)
if backend is None:
backend = get_safe_default_video_backend()
if backend == "torchcodec":
@@ -100,7 +91,6 @@ def decode_video_frames_pyav(
tolerance_s: float,
log_loaded_timestamps: bool = False,
return_uint8: bool = False,
is_depth: bool = False,
) -> torch.Tensor:
"""Loads frames associated to the requested timestamps of a video using PyAV.
@@ -150,13 +140,9 @@ def decode_video_frames_pyav(
current_ts = float(frame.pts * stream.time_base)
if log_loaded_timestamps:
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
if is_depth:
arr = frame.to_ndarray(format="gray12le") # (H, W) uint12
loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
else:
arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
# Convert to CHW uint8 to match torchcodec's output layout.
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
# Convert to CHW uint8 to match torchcodec's output layout.
arr = frame.to_ndarray(format="rgb24") # H, W, 3
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
loaded_ts.append(current_ts)
if current_ts >= last_ts:
break
@@ -207,15 +193,70 @@ def decode_video_frames_pyav(
return closest_frames
class VideoDecoderCache:
"""Thread-safe cache for video decoders to avoid expensive re-initialization."""
DEFAULT_DECODER_CACHE_SIZE = 100
"""Default LRU capacity for :class:`VideoDecoderCache`.
def __init__(self):
self._cache: dict[str, tuple[Any, Any]] = {}
Sized to comfortably hold a small rolling window of episodes worth of decoders
(typical recipes: 2-4 cameras per episode × tens of episodes in flight) while
bounding host RAM. Each cached entry retains a torchcodec ``VideoDecoder`` plus
an open ``fsspec`` file handle — on the order of a few MB per entry. Override
via the ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var or by passing ``max_size``
to the constructor (``None`` restores the legacy unbounded behaviour).
"""
def _default_max_cache_size() -> int | None:
raw = os.environ.get("LEROBOT_VIDEO_DECODER_CACHE_SIZE")
if raw is None:
return DEFAULT_DECODER_CACHE_SIZE
raw = raw.strip().lower()
if raw in ("", "none", "unbounded", "-1"):
return None
try:
value = int(raw)
except ValueError as e:
raise ValueError(
f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be an integer, 'none', or '-1'; got {raw!r}"
) from e
if value <= 0:
raise ValueError(f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be positive; got {value}")
return value
class VideoDecoderCache:
"""Thread-safe LRU cache for torchcodec ``VideoDecoder`` instances.
Cached entries hold a ``VideoDecoder`` plus the open ``fsspec`` file handle
backing it. When the cache is full and a new path is requested, the
least-recently-used entry is evicted and its file handle is closed. This
bounds host-RAM growth when iterating over datasets with many distinct
video files (otherwise each ``DataLoader`` worker pins every decoder it has
ever opened until the process exits).
Args:
max_size: Maximum number of decoders to retain. ``None`` disables
eviction and restores legacy unbounded behaviour. Defaults to the
value of ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` if set, otherwise
:data:`DEFAULT_DECODER_CACHE_SIZE`.
"""
_SENTINEL: ClassVar[object] = object()
def __init__(self, max_size: int | None | object = _SENTINEL):
if max_size is VideoDecoderCache._SENTINEL:
max_size = _default_max_cache_size()
if max_size is not None and max_size <= 0:
raise ValueError(f"max_size must be positive or None; got {max_size}")
self.max_size: int | None = max_size # type: ignore[assignment]
self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
self._lock = Lock()
def __contains__(self, video_path: object) -> bool:
with self._lock:
return str(video_path) in self._cache
def get_decoder(self, video_path: str):
"""Get a cached decoder or create a new one."""
"""Get a cached decoder or create a new one, evicting LRU if at capacity."""
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
@@ -227,22 +268,36 @@ class VideoDecoderCache:
video_path = str(video_path)
with self._lock:
if video_path not in self._cache:
file_handle = fsspec.open(video_path).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
file_handle.close()
raise
self._cache[video_path] = (decoder, file_handle)
entry = self._cache.get(video_path)
if entry is not None:
self._cache.move_to_end(video_path)
return entry[0]
return self._cache[video_path][0]
file_handle = fsspec.open(video_path).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
file_handle.close()
raise
self._cache[video_path] = (decoder, file_handle)
# Evict LRU entries until we are back under the cap. We close
# evicted file handles immediately; the associated ``VideoDecoder``
# is released to the GC when its last reference goes away.
if self.max_size is not None:
while len(self._cache) > self.max_size:
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
with contextlib.suppress(Exception):
evicted_handle.close()
return decoder
def clear(self):
"""Clear the cache and close file handles."""
"""Clear the cache and close all file handles."""
with self._lock:
for _, file_handle in self._cache.values():
file_handle.close()
with contextlib.suppress(Exception):
file_handle.close()
self._cache.clear()
def size(self) -> int:
@@ -351,17 +406,17 @@ def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,
fps: int,
video_encoder: VideoEncoderConfig | None = None,
camera_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
*,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
if video_encoder is None:
video_encoder = camera_encoder_defaults()
vcodec = video_encoder.vcodec
pix_fmt = video_encoder.pix_fmt
if camera_encoder is None:
camera_encoder = camera_encoder_defaults()
vcodec = camera_encoder.vcodec
pix_fmt = camera_encoder.pix_fmt
video_path = Path(video_path)
imgs_dir = Path(imgs_dir)
@@ -373,8 +428,7 @@ def encode_video_frames(
video_path.parent.mkdir(parents=True, exist_ok=True)
# Get input frames
suffix = ".png" if not isinstance(video_encoder, DepthEncoderConfig) else ".tiff"
template = "frame-" + ("[0-9]" * 6) + suffix
template = "frame-" + ("[0-9]" * 6) + ".png"
input_list = sorted(
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
)
@@ -384,7 +438,7 @@ def encode_video_frames(
with Image.open(input_list[0]) as dummy_image:
width, height = dummy_image.size
video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
# Set logging level
if log_level is not None:
@@ -420,6 +474,92 @@ def encode_video_frames(
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
def reencode_video(
input_video_path: Path | str,
output_video_path: Path | str,
camera_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
) -> None:
"""Re-encode a video file using the given encoder configuration.
Args:
input_video_path: Existing video file to read.
output_video_path: Path for the re-encoded file.
camera_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
"""
camera_encoder = camera_encoder or camera_encoder_defaults()
output_video_path = Path(output_video_path)
if output_video_path.exists() and not overwrite:
logger.warning(f"Video file already exists: {output_video_path}. Skipping re-encode.")
return
output_video_path.parent.mkdir(parents=True, exist_ok=True)
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
vcodec = camera_encoder.vcodec
pix_fmt = camera_encoder.pix_fmt
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
tmp_output_video_path = tmp_named_file.name
if log_level is not None:
logging.getLogger("libav").setLevel(log_level)
try:
with av.open(input_video_path, mode="r") as src:
try:
in_stream = src.streams.video[0]
except IndexError as e:
raise ValueError(f"No video stream in {input_video_path}") from e
fps = (
in_stream.base_rate
) # We allow fractional fps though LeRobotDataset only supports integer fps
width = int(in_stream.width)
height = int(in_stream.height)
with av.open(
tmp_output_video_path,
mode="w",
options={
"movflags": "faststart"
}, # faststart is to move the metadata to the beginning of the file to speed up loading
) as dst:
out_stream = dst.add_stream(vcodec, fps, options=video_options)
out_stream.pix_fmt = pix_fmt
out_stream.width = width
out_stream.height = height
for frame in src.decode(in_stream):
frame = frame.reformat(width=width, height=height, format=pix_fmt)
packet = out_stream.encode(frame)
if packet:
dst.mux(packet)
packet = out_stream.encode()
if packet:
dst.mux(packet)
shutil.move(tmp_output_video_path, output_video_path)
except Exception:
Path(tmp_output_video_path).unlink(missing_ok=True)
raise
finally:
if log_level is not None:
av.logging.restore_default_callback()
if not output_video_path.exists():
raise OSError(f"Video re-encoding did not work. File not found: {output_video_path}.")
def concatenate_video_files(
input_video_paths: list[Path | str],
output_video_path: Path,
@@ -536,21 +676,22 @@ class _CameraEncoderThread(threading.Thread):
self,
video_path: Path,
fps: int,
video_encoder: VideoEncoderConfig,
vcodec: str,
pix_fmt: str,
codec_options: dict[str, str],
frame_queue: queue.Queue,
result_queue: queue.Queue,
stop_event: threading.Event,
encoder_threads: int | None = None,
):
super().__init__(daemon=True)
self.video_path = video_path
self.fps = fps
self.video_encoder = video_encoder
self.is_depth = isinstance(video_encoder, DepthEncoderConfig)
self.vcodec = vcodec
self.pix_fmt = pix_fmt
self.codec_options = codec_options
self.frame_queue = frame_queue
self.result_queue = result_queue
self.stop_event = stop_event
self.encoder_threads = encoder_threads
def run(self) -> None:
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
@@ -575,12 +716,12 @@ class _CameraEncoderThread(threading.Thread):
# Sentinel: flush and close
break
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
# Ensure HWC uint8 numpy array
if isinstance(frame_data, np.ndarray):
if frame_data.ndim == 3 and frame_data.shape[0] in (1, 3):
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
# CHW -> HWC
frame_data = frame_data.transpose(1, 2, 0)
if not self.is_depth and frame_data.dtype != np.uint8:
if frame_data.dtype != np.uint8:
frame_data = (frame_data * 255).astype(np.uint8)
# Open container on first frame (to get width/height)
@@ -588,29 +729,15 @@ class _CameraEncoderThread(threading.Thread):
height, width = frame_data.shape[:2]
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
container = av.open(str(self.video_path), "w")
output_stream = container.add_stream(
self.video_encoder.vcodec,
self.fps,
options=self.video_encoder.get_codec_options(self.encoder_threads, as_strings=True),
)
output_stream.pix_fmt = self.video_encoder.pix_fmt
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
output_stream.pix_fmt = self.pix_fmt
output_stream.width = width
output_stream.height = height
output_stream.time_base = Fraction(1, self.fps)
# Encode frame with explicit timestamps
if not self.is_depth:
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
else:
video_frame = quantize_depth(
frame_data,
depth_min=self.video_encoder.depth_min,
depth_max=self.video_encoder.depth_max,
shift=self.video_encoder.shift,
use_log=self.video_encoder.use_log,
video_backend=self.video_encoder.video_backend,
)
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
video_frame.pts = frame_count
video_frame.time_base = Fraction(1, self.fps)
packet = output_stream.encode(video_frame)
@@ -669,7 +796,6 @@ class StreamingVideoEncoder:
self,
fps: int,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
queue_maxsize: int = 30,
encoder_threads: int | None = None,
):
@@ -685,7 +811,6 @@ class StreamingVideoEncoder:
"""
self.fps = fps
self._camera_encoder = camera_encoder or camera_encoder_defaults()
self._depth_encoder = depth_encoder or depth_encoder_defaults()
self._encoder_threads = encoder_threads
self.queue_maxsize = queue_maxsize
@@ -698,25 +823,18 @@ class StreamingVideoEncoder:
self._episode_active = False
self._closed = False
def start_episode(
self, video_keys: list[str], temp_dir: Path, depth_video_keys: list[str] | None = None
) -> None:
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
"""Start encoder threads for a new episode.
Args:
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
temp_dir: Base directory for temporary MP4 files
depth_video_keys: List of video feature keys that carry depth maps (e.g.
["observation.images.laptop_depth"]). Defaults to ``[]`` (no depth keys).
"""
if self._episode_active:
self.cancel_episode()
self._dropped_frames.clear()
if depth_video_keys is None:
depth_video_keys = []
for video_key in video_keys:
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
result_queue: queue.Queue = queue.Queue(maxsize=1)
@@ -725,15 +843,17 @@ class StreamingVideoEncoder:
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder
vcodec = self._camera_encoder.vcodec
codec_options = self._camera_encoder.get_codec_options(self._encoder_threads, as_strings=True)
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=self.fps,
video_encoder=encoder,
vcodec=vcodec,
pix_fmt=self._camera_encoder.pix_fmt,
codec_options=codec_options,
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
encoder_threads=self._encoder_threads,
)
encoder_thread.start()
@@ -940,13 +1060,13 @@ def get_audio_info(video_path: Path | str) -> dict:
def get_video_info(
video_path: Path | str,
video_encoder: VideoEncoderConfig | None = None,
camera_encoder: VideoEncoderConfig | None = None,
) -> dict:
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
Args:
video_path: Path to the encoded video file to probe.
video_encoder: If provided, record the exact encoder settings used to encode this
camera_encoder: If provided, record the exact encoder settings used to encode this
video. Stream-derived values take precedence — encoder fields are only written for keys
not already populated from the video file itself.
"""
@@ -966,10 +1086,13 @@ def get_video_info(
video_info["video.width"] = video_stream.width
video_info["video.codec"] = video_stream.codec.canonical_name
video_info["video.pix_fmt"] = video_stream.pix_fmt
video_info["video.is_depth_map"] = False
# Calculate fps from r_frame_rate
video_info["video.fps"] = int(video_stream.base_rate)
video_info["video.channels"] = get_pix_fmt_channels(video_stream.pix_fmt)
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
video_info["video.channels"] = pixel_channels
# Reset logging level
av.logging.restore_default_callback()
@@ -978,18 +1101,27 @@ def get_video_info(
video_info.update(**get_audio_info(video_path))
# Add additional encoder configuration if provided
if video_encoder is not None:
for field_name, field_value in asdict(video_encoder).items():
if camera_encoder is not None:
for field_name, field_value in asdict(camera_encoder).items():
# vcodec is already populated from the video stream
if field_name == "vcodec":
continue
video_info.setdefault(f"video.{field_name}", field_value)
video_info["is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
return video_info
def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
return 4
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
return 3
else:
raise ValueError("Unknown format")
def get_video_duration_in_s(video_path: Path | str) -> float:
"""
Get the duration of a video file in seconds using PyAV.

View File

@@ -18,12 +18,25 @@ from typing import TYPE_CHECKING
import numpy as np
from lerobot.utils.import_utils import _placo_available, require_package
from lerobot.utils.import_utils import require_package
if TYPE_CHECKING or _placo_available:
_placo_runtime_error: ImportError | None = None
if TYPE_CHECKING:
import placo # type: ignore[import-not-found]
else:
placo = None
try:
import placo # type: ignore[import-not-found]
except ImportError as _placo_import_err:
placo = None
_placo_runtime_error = _placo_import_err
def _raise_if_placo_unusable() -> None:
if placo is None and _placo_runtime_error is not None:
raise ImportError(
f"placo is installed but failed to import: {_placo_runtime_error!s}"
) from _placo_runtime_error
class RobotKinematics:
@@ -44,6 +57,7 @@ class RobotKinematics:
joint_names (list[str] | None): List of joint names to use for the kinematics solver
"""
require_package("placo", extra="placo-dep")
_raise_if_placo_unusable()
self.robot = placo.RobotWrapper(urdf_path)
self.solver = placo.KinematicsSolver(self.robot)

View File

@@ -43,6 +43,7 @@ from .tables import (
CAN_CMD_SET_ZERO,
DEFAULT_BAUDRATE,
DEFAULT_TIMEOUT_MS,
HANDSHAKE_TIMEOUT_S,
MODEL_RESOLUTION,
MOTOR_LIMIT_PARAMS,
NORMALIZED_DATA,
@@ -215,14 +216,16 @@ class RobstrideMotorsBus(MotorsBusBase):
self._is_connected = False
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
def _query_status_via_clear_fault(
self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
) -> tuple[bool, can.Message | None]:
motor_name = self._get_motor_name(motor)
motor_id = self._get_motor_id(motor_name)
recv_id = self._get_motor_recv_id(motor_name)
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self._bus().send(msg)
return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)
def _recv_status_via_clear_fault(
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
@@ -280,7 +283,7 @@ class RobstrideMotorsBus(MotorsBusBase):
faulted_motors = []
for motor_name in self.motors:
has_fault, msg = self._query_status_via_clear_fault(motor_name)
has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
if msg is None:
missing_motors.append(motor_name)
elif has_fault:
@@ -505,6 +508,87 @@ class RobstrideMotorsBus(MotorsBusBase):
return responses
def _recv_all_messages_until_quiet(
self,
*,
timeout: float = RUNNING_TIMEOUT,
max_messages: int = 4096,
) -> list[can.Message]:
"""
Receive frames until the bus goes quiet.
Args:
timeout: Poll timeout used for each recv() call. Collection stops
when one recv() times out (quiet gap).
max_messages: Safety cap to prevent unbounded loops.
"""
out: list[can.Message] = []
max_messages = max(1, max_messages)
timeout = max(0.0, timeout)
try:
while len(out) < max_messages:
msg = self._bus().recv(timeout=timeout)
if msg is None:
break
out.append(msg)
except (can.CanError, OSError) as e:
logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
return out
def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
"""
Decode all received feedback frames and update cached motor states.
Returns:
Set of payload recv_ids that were successfully mapped to motors.
"""
processed_recv_ids: set[int] = set()
for msg in messages:
if len(msg.data) < 1:
logger.debug(
f"Dropping short CAN frame on {self.port} "
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
)
continue
recv_id = int(msg.data[0])
motor_name = self._recv_id_to_motor.get(recv_id)
if motor_name is None:
logger.debug(
f"Unmapped CAN frame on {self.port} "
f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
)
continue
self._process_response(motor_name, msg)
processed_recv_ids.add(recv_id)
return processed_recv_ids
def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
"""
Drain pending RX frames from the CAN interface.
This is used by higher-level controllers to drop stale feedback before issuing
a fresh read cycle, so subsequent state reads are based on most recent replies.
It should also be called once when a controller instance is created/connected,
to clear residual frames left on the interface from previous sessions.
"""
drained = 0
poll_timeout_s = max(0.0, poll_timeout_s)
max_messages = max(1, max_messages)
try:
while drained < max_messages:
msg = self._bus().recv(timeout=poll_timeout_s)
if msg is None:
break
drained += 1
except (can.CanError, OSError) as e:
logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
return drained
def _speed_control(
self,
motor: NameOrID,
@@ -644,11 +728,14 @@ class RobstrideMotorsBus(MotorsBusBase):
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self._bus().send(msg)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
# Read every feedback frame until RX goes quiet, then decode all of them.
# This avoids dropping useful frames when responses from different motors interleave.
messages = self._recv_all_messages_until_quiet()
processed_recv_ids = self._process_feedback_messages(messages)
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
for recv_id, motor_name in recv_id_to_motor.items():
if msg := responses.get(recv_id):
self._process_response(motor_name, msg)
if recv_id not in processed_recv_ids:
logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
"""Convert float to unsigned integer for CAN transmission."""
@@ -711,7 +798,10 @@ class RobstrideMotorsBus(MotorsBusBase):
try:
self._decode_motor_state(msg.data)
except Exception as e:
logger.warning(f"Failed to decode response from {motor}: {e}")
logger.warning(
f"Failed to decode response from {motor} "
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
)
def _get_cached_value(self, motor: str, data_name: str) -> Value:
"""Retrieve a specific value from the state cache."""
@@ -848,20 +938,12 @@ class RobstrideMotorsBus(MotorsBusBase):
self._bus().send(msg)
updated_motors.append(motor)
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
for response in responses.values():
payload_motor_name = self._recv_id_to_motor.get(response.data[0])
if payload_motor_name is not None:
self._process_response(payload_motor_name, response)
else:
# Fallback: still attempt to decode based on payload byte0 mapping.
self._decode_motor_state(response.data)
messages = self._recv_all_messages_until_quiet()
processed_recv_ids = self._process_feedback_messages(messages)
for motor in updated_motors:
recv_id = self._get_motor_recv_id(motor)
if recv_id not in responses:
if recv_id not in processed_recv_ids:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
def read_calibration(self) -> dict[str, MotorCalibration]:

View File

@@ -114,7 +114,8 @@ CAN_CMD_SAVE_PARAM = 0xAA
CAN_PARAM_ID = 0x7FF
RUNNING_TIMEOUT = 0.001
RUNNING_TIMEOUT = 0.003
HANDSHAKE_TIMEOUT_S = 0.05
PARAM_TIMEOUT = 0.01
STATE_CACHE_TTL_S = 0.02

View File

@@ -14,7 +14,7 @@
# limitations under the License.
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
@@ -26,9 +26,14 @@ from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from huggingface_hub.dataclasses import strict
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
def strict(cls):
return cls
AutoConfig = None
AutoModel = None
PretrainedConfig = object
@@ -173,19 +178,20 @@ N_COLOR_CHANNELS = 3
# config
@strict
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
backbone_cfg: dict
action_head_cfg: dict
action_horizon: int
action_dim: int
backbone_cfg: dict[str, Any] | None = None
action_head_cfg: dict[str, Any] | None = None
action_horizon: int = 0
action_dim: int = 0
compute_dtype: str = "float32"
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
def __post_init__(self, **kwargs):
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
super().__post_init__(**kwargs)
# real model

View File

@@ -15,7 +15,6 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
DynamicCache = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
@@ -141,6 +142,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
return att_2d_masks & pad_2d_masks
def clone_past_key_values(past_key_values):
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
return DynamicCache(
tuple(
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
)
)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -227,16 +237,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
@@ -258,15 +265,16 @@ def compute_layer_complete(
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
cos, sin = rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
paligemma_layer = layers[0]
scaling = paligemma_layer.self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
paligemma_layer.self_attn,
query_states,
key_states,
value_states,
@@ -274,13 +282,13 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma_layer.self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
@@ -488,8 +496,9 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
paligemma_layers = self.paligemma.model.language_model.layers
gemma_expert_layers = self.gemma_expert.model.layers
rotary_emb = self.paligemma.model.language_model.rotary_emb
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
@@ -499,36 +508,39 @@ class PaliGemmaWithExpertModel(
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
# final norm
final_norms = (
self.paligemma.model.language_model.norm,
self.gemma_expert.model.norm,
)
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -907,7 +919,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
past_key_values = clone_past_key_values(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,

View File

@@ -15,7 +15,6 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
DynamicCache = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
@@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
return att_2d_masks & pad_2d_masks
def clone_past_key_values(past_key_values):
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
return DynamicCache(
tuple(
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
)
)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -224,16 +234,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
@@ -255,15 +262,16 @@ def compute_layer_complete(
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
cos, sin = rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
paligemma_layer = layers[0]
scaling = paligemma_layer.self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
paligemma_layer.self_attn,
query_states,
key_states,
value_states,
@@ -271,13 +279,13 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma_layer.self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
@@ -485,8 +493,9 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
paligemma_layers = self.paligemma.model.language_model.layers
gemma_expert_layers = self.gemma_expert.model.layers
rotary_emb = self.paligemma.model.language_model.rotary_emb
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
@@ -496,36 +505,39 @@ class PaliGemmaWithExpertModel(
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
# final norm
final_norms = (
self.paligemma.model.language_model.norm,
self.gemma_expert.model.norm,
)
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -880,7 +892,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
past_key_values = clone_past_key_values(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,

View File

@@ -95,6 +95,13 @@ from .relative_action_processor import (
from .rename_processor import RenameObservationsProcessorStep, rename_stats
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
# RenderMessagesStep is intentionally NOT re-exported here: it pulls in
# `lerobot.datasets.language`, which requires the `[dataset]` extra
# (`datasets`, `pyarrow`). Importing it from the processor package would
# break every base-install consumer of `lerobot.processor`. Users that
# need it import directly:
# from lerobot.processor.render_messages_processor import RenderMessagesStep
__all__ = [
"ActionProcessorStep",
"AddTeleopActionAsComplimentaryDataStep",

View File

@@ -174,6 +174,24 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
complementary_data.pop("language_persistent", None)
complementary_data.pop("language_events", None)
if "messages" in complementary_data:
messages = complementary_data["messages"]
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
complementary_data["messages"] = [messages]
if "message_streams" in complementary_data:
streams = complementary_data["message_streams"]
if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
complementary_data["message_streams"] = [streams]
if "target_message_indices" in complementary_data:
indices = complementary_data["target_message_indices"]
if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
complementary_data["target_message_indices"] = [indices]
return complementary_data
def transform_features(

View File

@@ -153,26 +153,30 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
return x
_COMPLEMENTARY_KEYS = (
"task",
"index",
"task_index",
"episode_index",
"timestamp",
"language_persistent",
"language_events",
"messages",
"message_streams",
"target_message_indices",
)
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""
Extract complementary data from a batch dictionary.
"""Extract complementary data from a batch dictionary.
This includes padding flags, task description, and indices.
Args:
batch: The batch dictionary.
Returns:
A dictionary with the extracted complementary data.
Includes padding flags (any key containing ``_is_pad``) plus the fixed
set of metadata / language keys defined in ``_COMPLEMENTARY_KEYS`` —
each only when present in ``batch``.
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
extras = {k: batch[k] for k in _COMPLEMENTARY_KEYS if k in batch}
return {**pad_keys, **extras}
def create_transition(

View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.configs.recipe import TrainingRecipe
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
from lerobot.datasets.language_render import render_sample
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.utils import unwrap_scalar
from .pipeline import ProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="render_messages_processor")
class RenderMessagesStep(ProcessorStep):
"""Processor step that turns raw language columns into rendered chat messages.
Reads ``language_persistent`` and ``language_events`` from the transition's
complementary data, renders them through ``recipe`` at the sample timestamp,
and replaces the raw columns with the resulting ``messages`` /
``message_streams`` / ``target_message_indices`` keys.
"""
recipe: TrainingRecipe
dataset_ctx: Any | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
"""Render messages for a single transition; return ``None`` to drop it."""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
events = complementary_data.get(LANGUAGE_EVENTS) or []
if not persistent and not events:
return transition
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
sample_idx = complementary_data.get("index", 0)
rendered = render_sample(
recipe=self.recipe,
persistent=persistent,
events=events,
t=unwrap_scalar(timestamp),
sample_idx=int(unwrap_scalar(sample_idx)),
task=complementary_data.get("task"),
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
return None
new_transition = transition.copy()
new_complementary_data = dict(complementary_data)
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Pass features through unchanged; rendering only touches complementary data."""
return features

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_rebot_b601_follower import BiRebotB601Follower
from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
__all__ = ["BiRebotB601Follower", "BiRebotB601FollowerConfig"]

View File

@@ -0,0 +1,150 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
from ..robot import Robot
from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
logger = logging.getLogger(__name__)
class BiRebotB601Follower(Robot):
"""Bimanual Seeed Studio reBot B601-DM follower.
Composes two single-arm :class:`RebotB601Follower` instances. Observation and
action keys of each arm are namespaced with a ``left_`` / ``right_`` prefix.
"""
config_class = BiRebotB601FollowerConfig
name = "bi_rebot_b601_follower"
def __init__(self, config: BiRebotB601FollowerConfig):
super().__init__(config)
self.config = config
left_arm_config = RebotB601FollowerRobotConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
can_adapter=config.left_arm_config.can_adapter,
dm_serial_baud=config.left_arm_config.dm_serial_baud,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=config.left_arm_config.cameras,
motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
joint_limits=config.left_arm_config.joint_limits,
)
right_arm_config = RebotB601FollowerRobotConfig(
id=f"{config.id}_right" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.right_arm_config.port,
can_adapter=config.right_arm_config.can_adapter,
dm_serial_baud=config.right_arm_config.dm_serial_baud,
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
max_relative_target=config.right_arm_config.max_relative_target,
cameras=config.right_arm_config.cameras,
motor_can_ids=config.right_arm_config.motor_can_ids,
pos_vel_velocity=config.right_arm_config.pos_vel_velocity,
gripper_torque_ratio=config.right_arm_config.gripper_torque_ratio,
joint_limits=config.right_arm_config.joint_limits,
)
self.left_arm = RebotB601Follower(left_arm_config)
self.right_arm = RebotB601Follower(right_arm_config)
# Only for compatibility with parts of the codebase that expect `robot.cameras`.
self.cameras = {**self.left_arm.cameras, **self.right_arm.cameras}
@property
def _motors_ft(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()},
**{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()},
}
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
**{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()},
**{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()},
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict = {}
obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()})
obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()})
return obs_dict
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
left_action = {
key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_")
}
right_action = {
key.removeprefix("right_"): value for key, value in action.items() if key.startswith("right_")
}
sent_action_left = self.left_arm.send_action(left_action)
sent_action_right = self.right_arm.send_action(right_action)
return {
**{f"left_{k}": v for k, v in sent_action_left.items()},
**{f"right_{k}": v for k, v in sent_action_right.items()},
}
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()

View File

@@ -0,0 +1,29 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import RobotConfig
from ..rebot_b601_follower import RebotB601FollowerConfig
@RobotConfig.register_subclass("bi_rebot_b601_follower")
@dataclass
class BiRebotB601FollowerConfig(RobotConfig):
"""Configuration class for the bimanual reBot B601-DM follower robot."""
left_arm_config: RebotB601FollowerConfig
right_arm_config: RebotB601FollowerConfig

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_rebot_b601_follower import RebotB601FollowerConfig, RebotB601FollowerRobotConfig
from .rebot_b601_follower import RebotB601Follower
__all__ = ["RebotB601Follower", "RebotB601FollowerConfig", "RebotB601FollowerRobotConfig"]

View File

@@ -0,0 +1,94 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig
@dataclass
class RebotB601FollowerConfig:
"""Base configuration class for the Seeed Studio reBot B601-DM follower arm.
The B601-DM is a 6-DOF arm plus gripper driven by Damiao CAN motors. Motor
communication goes through the ``motorbridge`` package.
"""
# Communication port. For ``can_adapter="damiao"`` this is the Damiao serial
# bridge device (e.g. "/dev/ttyACM0"); for ``can_adapter="socketcan"`` it is
# the CAN channel name (e.g. "can0").
port: str
# CAN adapter type:
# "damiao" - Damiao dedicated serial bridge (default)
# "socketcan" - SocketCAN based adapters (PCAN, slcan, embedded controllers, ...)
can_adapter: str = "damiao"
# Baud rate for the Damiao serial bridge (only used when can_adapter="damiao").
dm_serial_baud: int = 921600
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target
# vector for safety purposes (in degrees). Set to a positive scalar to apply the
# same value to all motors, or to a dict mapping motor names to per-motor values.
max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# Maps motor names to their (send_can_id, recv_can_id) pair.
motor_can_ids: dict[str, tuple[int, int]] = field(
default_factory=lambda: {
"shoulder_pan": (0x01, 0x11),
"shoulder_lift": (0x02, 0x12),
"elbow_flex": (0x03, 0x13),
"wrist_flex": (0x04, 0x14),
"wrist_yaw": (0x05, 0x15),
"wrist_roll": (0x06, 0x16),
"gripper": (0x07, 0x17),
}
)
# Target velocity for joints running in POS_VEL mode, in degrees/s. A scalar is
# applied to every joint; a list provides one value per joint (in motor order).
pos_vel_velocity: float | list[float] = field(default_factory=lambda: [150.0] * 7)
# Torque/current ratio for the gripper's FORCE_POS mode, in range [0, 1].
gripper_torque_ratio: float = 0.1
# Soft joint limits (degrees). These are clipped against on every action.
joint_limits: dict[str, tuple[float, float]] = field(
default_factory=lambda: {
"shoulder_pan": (-145.0, 145.0),
"shoulder_lift": (-170.0, 1.0),
"elbow_flex": (-200.0, 1.0),
"wrist_flex": (-80.0, 90.0),
"wrist_yaw": (-90.0, 90.0),
"wrist_roll": (-90.0, 90.0),
"gripper": (-270.0, 0.0),
}
)
@RobotConfig.register_subclass("rebot_b601_follower")
@dataclass
class RebotB601FollowerRobotConfig(RobotConfig, RebotB601FollowerConfig):
"""Registered configuration for the reBot B601-DM follower robot."""
pass

View File

@@ -0,0 +1,289 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import time
from functools import cached_property
from typing import TYPE_CHECKING
from lerobot.cameras import make_cameras_from_configs
from lerobot.motors import MotorCalibration
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _motorbridge_available, require_package
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .config_rebot_b601_follower import RebotB601FollowerRobotConfig
if TYPE_CHECKING or _motorbridge_available:
from motorbridge import Controller as MotorBridgeController, Mode as MotorBridgeMode
else:
MotorBridgeController = None
MotorBridgeMode = None
logger = logging.getLogger(__name__)
# Joint controlled in FORCE_POS mode; every other joint runs in POS_VEL mode.
GRIPPER_MOTOR = "gripper"
# Per-joint Damiao motor models for the B601-DM (passed to motorbridge).
MOTOR_MODELS = {
"shoulder_pan": "4340P",
"shoulder_lift": "4340P",
"elbow_flex": "4340P",
"wrist_flex": "4310",
"wrist_yaw": "4310",
"wrist_roll": "4310",
"gripper": "4310",
}
_ENSURE_MODE_RETRIES = 9
_SETTLE_SEC = 0.01
_ZERO_SETTLE_SEC = 0.1
class RebotB601Follower(Robot):
"""Seeed Studio reBot B601-DM follower arm (6-DOF + gripper, Damiao CAN motors).
Motor communication is handled by the ``motorbridge`` package over a CAN bus,
reached either through a Damiao serial bridge or a SocketCAN adapter.
"""
config_class = RebotB601FollowerRobotConfig
name = "rebot_b601_follower"
def __init__(self, config: RebotB601FollowerRobotConfig):
require_package("motorbridge", extra="rebot")
super().__init__(config)
self.config = config
self.bus: MotorBridgeController | None = None
self.motors: dict = {}
self.motor_names = list(config.motor_can_ids.keys())
self.cameras = make_cameras_from_configs(config.cameras)
@property
def _motors_ft(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.motor_names}
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.bus is not None and all(cam.is_connected for cam in self.cameras.values())
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
logger.info(f"Connecting {self} on {self.config.port} (adapter={self.config.can_adapter})...")
if self.config.can_adapter == "damiao":
self.bus = MotorBridgeController.from_dm_serial(
serial_port=self.config.port,
baud=self.config.dm_serial_baud,
)
elif self.config.can_adapter == "socketcan":
self.bus = MotorBridgeController(channel=self.config.port)
else:
raise ValueError(
f"Unsupported can_adapter '{self.config.can_adapter}'. Use 'damiao' or 'socketcan'."
)
for motor_name, (send_id, recv_id) in self.config.motor_can_ids.items():
self.motors[motor_name] = self.bus.add_damiao_motor(send_id, recv_id, MOTOR_MODELS[motor_name])
if not self.is_calibrated and calibrate:
logger.info(
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
)
self.calibrate()
for cam in self.cameras.values():
cam.connect()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return bool(self.calibration)
def calibrate(self) -> None:
if self.calibration:
user_input = input(
f"Press ENTER to use provided calibration file associated with the id {self.id}, "
"or type 'c' and press ENTER to run calibration: "
)
if user_input.strip().lower() != "c":
logger.info(f"Using calibration file associated with the id {self.id}")
return
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_all()
print(
"\nCalibration: set zero position.\n"
"Manually move the reBot B601 to its ZERO POSITION and close the gripper.\n"
"See the B601 manual for the zero pose (the default sit-down position).\n"
)
input("Press ENTER when ready...")
for motor in self.motors.values():
motor.set_zero_position()
time.sleep(_ZERO_SETTLE_SEC)
logger.info("Arm zero position set.")
self.calibration = {}
for motor_name, (send_id, _recv_id) in self.config.motor_can_ids.items():
range_min, range_max = self.config.joint_limits[motor_name]
self.calibration[motor_name] = MotorCalibration(
id=send_id,
drive_mode=0,
homing_offset=0,
range_min=int(range_min),
range_max=int(range_max),
)
self._save_calibration()
print(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
self.bus.enable_all()
for motor_name, motor in self.motors.items():
target_mode = (
MotorBridgeMode.FORCE_POS if motor_name == GRIPPER_MOTOR else MotorBridgeMode.POS_VEL
)
for attempt in range(_ENSURE_MODE_RETRIES + 1):
try:
motor.ensure_mode(target_mode)
break
except Exception:
if attempt == _ENSURE_MODE_RETRIES:
raise
time.sleep(_SETTLE_SEC)
logger.debug(f"{motor_name} mode set to {target_mode}")
@check_if_not_connected
def disable_torque(self) -> None:
"""Disable motor torque so the arm can be moved by hand (read-only debugging)."""
self.bus.disable_all()
logger.info(f"{self} torque disabled.")
def _present_pos(self) -> dict[str, float]:
"""Read present joint positions in degrees."""
for motor in self.motors.values():
motor.request_feedback()
try:
self.bus.poll_feedback_once()
except Exception:
logger.warning("CAN bus poll feedback failed.")
present_pos = {}
for motor_name, motor in self.motors.items():
state = motor.get_state()
present_pos[motor_name] = math.degrees(state.pos) if state is not None else 0.0
return present_pos
@check_if_not_connected
def get_observation(self) -> RobotObservation:
start = time.perf_counter()
obs_dict = {f"{motor}.pos": pos for motor, pos in self._present_pos().items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.read_latest()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
"""Command the arm to a target joint configuration.
Positions are expressed in degrees. The relative action magnitude may be
clipped depending on `max_relative_target`, so the action actually sent is
always returned.
"""
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
# Clip against soft joint limits.
for motor_name in list(goal_pos):
if motor_name in self.config.joint_limits:
min_limit, max_limit = self.config.joint_limits[motor_name]
clipped = max(min_limit, min(max_limit, goal_pos[motor_name]))
if clipped != goal_pos[motor_name]:
logger.debug(f"Clipped {motor_name} from {goal_pos[motor_name]:.2f} to {clipped:.2f}")
goal_pos[motor_name] = clipped
# Tolerate 6-DOF leaders that have no wrist_yaw joint by holding it at zero.
# This is intentional: it lets a 6-DOF leader such as the SO-100 / SO-101
# (so100_leader / so101_leader) teleoperate this 7-DOF follower — the missing
# wrist_yaw command is simply treated as 0.0 instead of raising.
if "wrist_yaw" not in goal_pos:
goal_pos["wrist_yaw"] = 0.0
# Cap relative target when too far from the present position.
if self.config.max_relative_target is not None:
present_pos = self._present_pos()
goal_present_pos = {key: (g, present_pos.get(key, g)) for key, g in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
for motor_name, position_deg in goal_pos.items():
motor = self.motors.get(motor_name)
if motor is None:
continue
idx = self.motor_names.index(motor_name)
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
pos_rad = math.radians(position_deg)
vel_rad = math.radians(vel_deg_s)
if motor_name == GRIPPER_MOTOR:
motor.send_force_pos(pos_rad, vel_rad, self.config.gripper_torque_ratio)
else:
motor.send_pos_vel(pos_rad, vel_rad)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
@check_if_not_connected
def disconnect(self) -> None:
for motor in self.motors.values():
if self.config.disable_torque_on_disconnect:
motor.disable()
motor.clear_error()
motor.close()
self.bus.close()
self.bus = None
self.motors = {}
for cam in self.cameras.values():
cam.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -68,12 +68,9 @@ class SOFollower(Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
features: dict[str, tuple] = {}
for cam in self.cameras:
features[cam] = (self.cameras[cam].height, self.cameras[cam].width, 3)
if getattr(self.cameras[cam], "use_depth", False):
features[f"{cam}_depth"] = (self.cameras[cam].height, self.cameras[cam].width, 1)
return features
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -193,12 +190,6 @@ class SOFollower(Robot):
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
if getattr(cam, "use_depth", False):
start = time.perf_counter()
obs_dict[f"{cam_key}_depth"] = cam.read_latest_depth()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key} depth: {dt_ms:.1f}ms")
return obs_dict
@check_if_not_connected

View File

@@ -68,6 +68,14 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .bi_openarm_follower import BiOpenArmFollower
return BiOpenArmFollower(config)
elif config.type == "rebot_b601_follower":
from .rebot_b601_follower import RebotB601Follower
return RebotB601Follower(config)
elif config.type == "bi_rebot_b601_follower":
from .bi_rebot_b601_follower import BiRebotB601Follower
return BiRebotB601Follower(config)
elif config.type == "mock_robot":
from tests.mocks.mock_robot import MockRobot

View File

@@ -333,7 +333,6 @@ def build_rollout_context(
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
camera_encoder=cfg.dataset.camera_encoder,
depth_encoder=cfg.dataset.depth_encoder,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,
@@ -369,7 +368,6 @@ def build_rollout_context(
* len(robot.cameras if hasattr(robot, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
camera_encoder=cfg.dataset.camera_encoder,
depth_encoder=cfg.dataset.depth_encoder,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,

View File

@@ -39,6 +39,7 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_rebot_b601_follower,
bi_so_follower,
hope_jr,
koch_follower,
@@ -46,12 +47,14 @@ from lerobot.robots import ( # noqa: F401
make_robot_from_config,
omx_follower,
openarm_follower,
rebot_b601_follower,
so_follower,
)
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
koch_leader,
@@ -59,6 +62,7 @@ from lerobot.teleoperators import ( # noqa: F401
omx_leader,
openarm_leader,
openarm_mini,
rebot_102_leader,
so_leader,
unitree_g1,
)

View File

@@ -178,6 +178,31 @@ Recompute stats for relative actions and push to hub:
--operation.num_workers 4 \
--push_to_hub true
Re-encode all videos in a dataset (saves to lerobot/pusht_reencoded by default):
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type reencode_videos \
--operation.camera_encoder.vcodec h264 \
--operation.camera_encoder.pix_fmt yuv420p \
--operation.camera_encoder.crf 23
Re-encode videos into a new dataset using 4 parallel processes:
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_h264 \
--operation.type reencode_videos \
--operation.camera_encoder.vcodec h264 \
--operation.camera_encoder.crf 23 \
--operation.num_workers 4
Re-encode videos in-place (overwrites original dataset):
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht \
--operation.type reencode_videos \
--operation.camera_encoder.vcodec h264 \
--operation.overwrite true
Using JSON config file:
lerobot-edit-dataset \
--config_path path/to/edit_config.json
@@ -200,6 +225,7 @@ from lerobot.datasets import (
merge_datasets,
modify_tasks,
recompute_stats,
reencode_dataset,
remove_feature,
split_dataset,
)
@@ -268,6 +294,15 @@ class RecomputeStatsConfig(OperationConfig):
overwrite: bool = False
@OperationConfig.register_subclass("reencode_videos")
@dataclass
class ReencodeVideosConfig(OperationConfig):
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
num_workers: int = 0
encoder_threads: int | None = None
overwrite: bool = False
@OperationConfig.register_subclass("info")
@dataclass
class InfoConfig(OperationConfig):
@@ -634,6 +669,58 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
dataset.push_to_hub()
def handle_reencode_videos(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, ReencodeVideosConfig):
raise ValueError("Operation config must be ReencodeVideosConfig")
output_repo_id, input_root, output_root = _resolve_io_paths(
cfg.repo_id,
cfg.new_repo_id,
cfg.root,
cfg.new_root,
default_new_repo_id=f"{cfg.repo_id}_reencoded",
)
in_place = output_root == input_root
if in_place and not cfg.operation.overwrite:
raise ValueError(
f"reencode_videos would overwrite the dataset in-place at {input_root}. "
"Pass --operation.overwrite true to allow in-place modification, "
"or use --new_repo_id / --new_root to write to a different location. "
f"Default output repo_id when neither is set: '{cfg.repo_id}_reencoded'."
)
if in_place:
logging.warning(
f"Overwriting dataset videos in-place at {input_root}. The original videos will be lost."
)
dataset = LeRobotDataset(cfg.repo_id, root=input_root)
else:
logging.info(f"Copying dataset from {input_root} to {output_root}")
if output_root.exists():
backup_path = output_root.with_name(output_root.name + "_old")
logging.warning(f"Output directory {output_root} already exists. Moving to {backup_path}")
if backup_path.exists():
shutil.rmtree(backup_path)
shutil.move(output_root, backup_path)
shutil.copytree(input_root, output_root)
dataset = LeRobotDataset(output_repo_id, root=output_root)
logging.info(f"Re-encoding videos in {output_repo_id} with {cfg.operation.camera_encoder}")
reencode_dataset(
dataset,
camera_encoder=cfg.operation.camera_encoder,
encoder_threads=cfg.operation.encoder_threads,
num_workers=cfg.operation.num_workers,
)
logging.info(f"All videos re-encoded at {dataset.root}")
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {output_repo_id}...")
dataset.push_to_hub()
def _get_dataset_size(repo_path):
import os
@@ -707,6 +794,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_convert_image_to_video(cfg)
elif operation_type == "recompute_stats":
handle_recompute_stats(cfg)
elif operation_type == "reencode_videos":
handle_reencode_videos(cfg)
elif operation_type == "info":
handle_info(cfg)
else:

View File

@@ -45,16 +45,19 @@ from lerobot.model import RobotKinematics
from lerobot.robots import ( # noqa: F401
RobotConfig,
bi_openarm_follower,
bi_rebot_b601_follower,
bi_so_follower,
koch_follower,
make_robot_from_config,
omx_follower,
openarm_follower,
rebot_b601_follower,
so_follower,
)
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
bi_openarm_leader,
bi_rebot_102_leader,
bi_so_leader,
gamepad,
koch_leader,
@@ -62,6 +65,7 @@ from lerobot.teleoperators import ( # noqa: F401
omx_leader,
openarm_leader,
openarm_mini,
rebot_102_leader,
so_leader,
)
from lerobot.utils.robot_utils import precise_sleep

View File

@@ -120,6 +120,7 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_rebot_b601_follower,
bi_so_follower,
earthrover_mini_plus,
hope_jr,
@@ -128,6 +129,7 @@ from lerobot.robots import ( # noqa: F401
omx_follower,
openarm_follower,
reachy2,
rebot_b601_follower,
so_follower,
unitree_g1 as unitree_g1_robot,
)
@@ -135,6 +137,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
koch_leader,
@@ -143,6 +146,7 @@ from lerobot.teleoperators import ( # noqa: F401
openarm_leader,
openarm_mini,
reachy2_teleoperator,
rebot_102_leader,
so_leader,
unitree_g1,
)
@@ -399,7 +403,6 @@ def record(
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
camera_encoder=cfg.dataset.camera_encoder,
depth_encoder=cfg.dataset.depth_encoder,
encoder_threads=cfg.dataset.encoder_threads,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
@@ -429,7 +432,6 @@ def record(
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
camera_encoder=cfg.dataset.camera_encoder,
depth_encoder=cfg.dataset.depth_encoder,
encoder_threads=cfg.dataset.encoder_threads,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,

View File

@@ -56,6 +56,7 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_rebot_b601_follower,
bi_so_follower,
earthrover_mini_plus,
hope_jr,
@@ -64,6 +65,7 @@ from lerobot.robots import ( # noqa: F401
omx_follower,
openarm_follower,
reachy2,
rebot_b601_follower,
so_follower,
unitree_g1,
)

View File

@@ -144,6 +144,7 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_rebot_b601_follower,
bi_so_follower,
earthrover_mini_plus,
hope_jr,
@@ -151,6 +152,7 @@ from lerobot.robots import ( # noqa: F401
omx_follower,
openarm_follower,
reachy2,
rebot_b601_follower,
so_follower,
unitree_g1 as unitree_g1_robot,
)
@@ -159,6 +161,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
koch_leader,
@@ -166,6 +169,7 @@ from lerobot.teleoperators import ( # noqa: F401
openarm_leader,
openarm_mini,
reachy2_teleoperator,
rebot_102_leader,
so_leader,
unitree_g1,
)

View File

@@ -30,20 +30,24 @@ import draccus
from lerobot.robots import ( # noqa: F401
RobotConfig,
bi_rebot_b601_follower,
bi_so_follower,
koch_follower,
lekiwi,
make_robot_from_config,
omx_follower,
rebot_b601_follower,
so_follower,
)
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
bi_rebot_102_leader,
bi_so_leader,
koch_leader,
make_teleoperator_from_config,
omx_leader,
openarm_mini,
rebot_102_leader,
so_leader,
)

View File

@@ -72,6 +72,7 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_rebot_b601_follower,
bi_so_follower,
earthrover_mini_plus,
hope_jr,
@@ -80,6 +81,7 @@ from lerobot.robots import ( # noqa: F401
omx_follower,
openarm_follower,
reachy2,
rebot_b601_follower,
so_follower,
unitree_g1 as unitree_g1_robot,
)
@@ -87,6 +89,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_rebot_102_leader,
bi_so_leader,
gamepad,
homunculus,
@@ -97,6 +100,7 @@ from lerobot.teleoperators import ( # noqa: F401
openarm_leader,
openarm_mini,
reachy2_teleoperator,
rebot_102_leader,
so_leader,
unitree_g1,
)

View File

@@ -48,6 +48,7 @@ from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.rewards import make_reward_pre_post_processors
from lerobot.utils.collate import lerobot_collate_fn
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
@@ -401,6 +402,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
shuffle = True
sampler = None
# Only swap in the language-aware collate when the dataset actually
# declares language columns; otherwise stay on PyTorch's default
# collate so non-language training runs are unaffected.
collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=cfg.num_workers,
@@ -409,6 +414,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
sampler=sampler,
pin_memory=device.type == "cuda",
drop_last=False,
collate_fn=collate_fn,
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
)

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_rebot_102_leader import BiRebotArm102Leader
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
__all__ = ["BiRebotArm102Leader", "BiRebotArm102LeaderConfig"]

View File

@@ -0,0 +1,113 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
from ..teleoperator import Teleoperator
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
logger = logging.getLogger(__name__)
class BiRebotArm102Leader(Teleoperator):
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
each arm are namespaced with a ``left_`` / ``right_`` prefix, so a bimanual
leader can teleoperate a bimanual reBot B601 follower.
"""
config_class = BiRebotArm102LeaderConfig
name = "bi_rebot_102_leader"
def __init__(self, config: BiRebotArm102LeaderConfig):
super().__init__(config)
self.config = config
left_arm_config = RebotArm102LeaderTeleopConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
baudrate=config.left_arm_config.baudrate,
joint_ids=config.left_arm_config.joint_ids,
joint_directions=config.left_arm_config.joint_directions,
joint_ranges=config.left_arm_config.joint_ranges,
)
right_arm_config = RebotArm102LeaderTeleopConfig(
id=f"{config.id}_right" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.right_arm_config.port,
baudrate=config.right_arm_config.baudrate,
joint_ids=config.right_arm_config.joint_ids,
joint_directions=config.right_arm_config.joint_directions,
joint_ranges=config.right_arm_config.joint_ranges,
)
self.left_arm = RebotArm102Leader(left_arm_config)
self.right_arm = RebotArm102Leader(right_arm_config)
@cached_property
def action_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.action_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.action_features.items()},
}
@cached_property
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def get_action(self) -> RobotAction:
action_dict = {}
action_dict.update({f"left_{k}": v for k, v in self.left_arm.get_action().items()})
action_dict.update({f"right_{k}": v for k, v in self.right_arm.get_action().items()})
return action_dict
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()

View File

@@ -0,0 +1,29 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
from ..rebot_102_leader import RebotArm102LeaderConfig
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
@dataclass
class BiRebotArm102LeaderConfig(TeleoperatorConfig):
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
left_arm_config: RebotArm102LeaderConfig
right_arm_config: RebotArm102LeaderConfig

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_rebot_102_leader import RebotArm102LeaderConfig, RebotArm102LeaderTeleopConfig
from .rebot_102_leader import RebotArm102Leader
__all__ = ["RebotArm102Leader", "RebotArm102LeaderConfig", "RebotArm102LeaderTeleopConfig"]

View File

@@ -0,0 +1,83 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from ..config import TeleoperatorConfig
@dataclass
class RebotArm102LeaderConfig:
"""Base configuration class for the Seeed Studio StarArm102 / reBot Arm 102 leader.
The reBot Arm 102 is a 7-joint (incl. gripper) leader arm driven by FashionStar
UART smart servos. Servo communication goes through ``motorbridge-smart-servo``.
"""
# USB-to-UART device the leader arm is connected to (e.g. "/dev/ttyUSB0").
port: str
baudrate: int = 1_000_000
# Servo id of each joint on the UART bus.
joint_ids: dict[str, int] = field(
default_factory=lambda: {
"shoulder_pan": 0,
"shoulder_lift": 1,
"elbow_flex": 2,
"wrist_flex": 3,
"wrist_yaw": 4,
"wrist_roll": 5,
"gripper": 6,
}
)
# Per-joint sign applied to raw servo angles so the leader matches the follower
# convention. The gripper additionally carries a scale (e.g. -6) to widen its
# range to the reBot B601 follower's gripper travel.
joint_directions: dict[str, int] = field(
default_factory=lambda: {
"shoulder_pan": -1,
"shoulder_lift": -1,
"elbow_flex": 1,
"wrist_flex": 1,
"wrist_yaw": 1,
"wrist_roll": -1,
"gripper": -6,
}
)
# Per-joint [min, max] output range in degrees. Matches the reBot B601 follower
# joint limits so leader actions can drive the follower key-for-key.
joint_ranges: dict[str, list[int]] = field(
default_factory=lambda: {
"shoulder_pan": [-150, 150],
"shoulder_lift": [-170, 1],
"elbow_flex": [-200, 1],
"wrist_flex": [-80, 90],
"wrist_yaw": [-90, 90],
"wrist_roll": [-90, 90],
"gripper": [-270, 0],
}
)
@TeleoperatorConfig.register_subclass("rebot_102_leader")
@dataclass
class RebotArm102LeaderTeleopConfig(TeleoperatorConfig, RebotArm102LeaderConfig):
"""Registered configuration for the reBot Arm 102 leader teleoperator."""
pass

View File

@@ -0,0 +1,207 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from typing import TYPE_CHECKING
from lerobot.motors import MotorCalibration
from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _motorbridge_smart_servo_available, require_package
from ..teleoperator import Teleoperator
from .config_rebot_102_leader import RebotArm102LeaderTeleopConfig
if TYPE_CHECKING or _motorbridge_smart_servo_available:
from motorbridge_smart_servo import FashionStarServo, ServoMonitor
else:
FashionStarServo = None
ServoMonitor = None
logger = logging.getLogger(__name__)
_SETTLE_SEC = 0.01
class RebotArm102Leader(Teleoperator):
"""Seeed Studio StarArm102 / reBot Arm 102 leader arm.
A 7-joint (incl. gripper) leader built on FashionStar UART smart servos. Servo
communication is handled by the ``motorbridge-smart-servo`` package; this class
only reads joint angles, so it produces actions but accepts no feedback.
"""
config_class = RebotArm102LeaderTeleopConfig
name = "rebot_102_leader"
def __init__(self, config: RebotArm102LeaderTeleopConfig):
require_package("motorbridge-smart-servo", extra="rebot", import_name="motorbridge_smart_servo")
super().__init__(config)
self.config = config
self.bus: FashionStarServo | None = None
self.motor_names = list(config.joint_ids.keys())
self._last_raw_positions: dict[str, float] = {}
@property
def action_features(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.motor_names}
@property
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.bus is not None
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
logger.info(f"Connecting {self} on {self.config.port}...")
bus = FashionStarServo(self.config.port, baudrate=self.config.baudrate)
try:
for motor_name, motor_id in self.config.joint_ids.items():
if not bus.ping(motor_id):
raise RuntimeError(f"Servo not found for {motor_name} (id={motor_id}).")
self._last_raw_positions[motor_name] = 0.0
self.bus = bus
if not self.is_calibrated and calibrate:
logger.info(
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
)
self.calibrate()
self.configure()
except Exception:
bus.close()
self.bus = None
raise
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return bool(self.calibration) and set(self.calibration) == set(self.motor_names)
def calibrate(self) -> None:
if self.calibration:
user_input = input(
f"Press ENTER to use provided calibration file associated with the id {self.id}, "
"or type 'c' and press ENTER to run calibration: "
)
if user_input.strip().lower() != "c":
logger.info(f"Using calibration file associated with the id {self.id}")
return
logger.info(f"\nRunning calibration of {self}")
input(
"\nCalibration: set zero position.\n"
"Manually move the reBot Arm 102 to its zero pose and close the gripper.\n"
"Press ENTER when ready..."
)
self.calibration = {}
for motor_name, motor_id in self.config.joint_ids.items():
self.bus.unlock(motor_id)
time.sleep(_SETTLE_SEC)
self.bus.set_origin_point(motor_id)
range_min, range_max = self.config.joint_ranges[motor_name]
self.calibration[motor_name] = MotorCalibration(
id=motor_id,
drive_mode=0,
homing_offset=0,
range_min=int(range_min),
range_max=int(range_max),
)
self._save_calibration()
logger.info(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
for motor_id in self.config.joint_ids.values():
self.bus.unlock(motor_id)
time.sleep(_SETTLE_SEC)
# Reset the multi-turn counter of each servo individually.
for motor_id in self.config.joint_ids.values():
self.bus.reset_multi_turn(motor_id)
def _read_raw_positions(self) -> dict[str, float]:
result: dict[int, ServoMonitor | None] = self.bus.sync_monitor(list(self.config.joint_ids.values()))
id_to_name = {v: k for k, v in self.config.joint_ids.items()}
raw_positions: dict[str, float] = {}
for motor_id, monitor in result.items():
motor_name = id_to_name[motor_id]
if monitor is None:
raise RuntimeError(f"Servo {motor_name} (id={motor_id}) has never responded.")
raw_positions[motor_name] = monitor.angle_deg
return raw_positions
@staticmethod
def _round_to_valid_range(value: float, min_value: float, max_value: float) -> tuple[float, int]:
"""Unwrap a multi-turn angle into the ±180° window centred on (min+max)/2.
The servo may report an angle that has accumulated extra full rotations
(value = true_angle + N*360). Subtract the nearest whole number of turns
to bring it back into [center-180, center+180]. Returns the unwrapped
angle and the number of turns removed.
"""
center = (min_value + max_value) / 2.0
turns = round((value - center) / 360.0)
return value - turns * 360.0, abs(turns)
@check_if_not_connected
def get_action(self) -> RobotAction:
start = time.perf_counter()
try:
raw_positions = self._read_raw_positions()
self._last_raw_positions = raw_positions
except Exception as e:
logger.error(f"Failed to read raw positions: {e}")
logger.warning("[EMERGENCY STOP] Hold the follower arm and cut off the main power to the arms.")
logger.warning(
"[EMERGENCY STOP] Break the teleoperation session and check the leader USB connection or power."
)
raw_positions = self._last_raw_positions
action_dict: dict[str, float] = {}
for motor_name in self.motor_names:
range_min, range_max = self.config.joint_ranges[motor_name]
direction = self.config.joint_directions[motor_name]
sign = 1.0 if direction >= 0 else -1.0
unwrapped, k = self._round_to_valid_range(
raw_positions[motor_name], range_min * sign, range_max * sign
)
position = unwrapped * direction
if k > 0:
logger.debug(
f"Servo {motor_name} (id={self.config.joint_ids[motor_name]}) wrapped {k} * 360°. "
f"Unwrapped pos: {unwrapped:.1f}° (raw: {raw_positions[motor_name]:.1f}°)"
)
action_dict[f"{motor_name}.pos"] = max(float(range_min), min(float(range_max), position))
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action_dict
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
@check_if_not_connected
def disconnect(self) -> None:
self.bus.close()
self.bus = None
logger.info(f"{self} disconnected.")

View File

@@ -99,6 +99,14 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
from .openarm_mini import OpenArmMini
return OpenArmMini(config)
elif config.type == "rebot_102_leader":
from .rebot_102_leader import RebotArm102Leader
return RebotArm102Leader(config)
elif config.type == "bi_rebot_102_leader":
from .bi_rebot_102_leader import BiRebotArm102Leader
return BiRebotArm102Leader(config)
else:
try:
return cast("Teleoperator", make_device_from_device_class(config))

View File

@@ -0,0 +1,65 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from torch.utils.data._utils.collate import default_collate
from lerobot.datasets.language import LANGUAGE_COLUMNS
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
"""Collate function that preserves Python-list and language fields as lists.
Drops ``None`` samples (e.g. recipes that yielded no target message), keeps
rendered-message and language fields as plain Python lists, and delegates
every other key to PyTorch's ``default_collate``.
"""
batch = [sample for sample in batch if sample is not None]
if not batch:
return None
# All-or-nothing per key: a partial-presence batch (e.g. half the samples
# carry `messages` and half don't) is a real bug in the upstream
# rendering step — silently filtering would hand downstream consumers a
# preserved list shorter than the tensor batch. Raise instead so the
# mismatch surfaces at the boundary.
preserved: dict[str, list[Any]] = {}
for key in _PYTHON_LIST_KEYS:
presence = [key in sample for sample in batch]
if not any(presence):
continue
if not all(presence):
raise ValueError(
f"Inconsistent batch: {sum(presence)}/{len(batch)} samples carry {key!r}; "
f"every sample in a batch must agree."
)
preserved[key] = [sample[key] for sample in batch]
tensorizable = [
{
key: value
for key, value in sample.items()
if key not in _PYTHON_LIST_KEYS and key not in LANGUAGE_COLUMNS
}
for sample in batch
]
collated = default_collate(tensorizable)
collated.update(preserved)
return collated

View File

@@ -69,7 +69,6 @@ def hw_to_dataset_features(
for key, ftype in hw_features.items()
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
}
# TODO(CarolinePascal): we should not rely on the shape to determine if a feature is a camera !
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
if joint_fts and prefix == ACTION:
@@ -87,19 +86,11 @@ def hw_to_dataset_features(
}
for key, shape in cam_fts.items():
dtype = "video" if use_video else "image"
if len(shape) == 3 and shape[2] in (1, 3):
features[f"{prefix}.images.{key}"] = {
"dtype": dtype,
"shape": shape,
"names": ["height", "width", "channels"],
"info": {"is_depth_map": shape[2] == 1},
}
else:
raise ValueError(
f"Camera feature '{key}' has shape {shape}. "
f"Expected a 3-tuple (H, W, C), e.g. (480, 640, 3) for RGB or (480, 640, 1) for depth."
)
features[f"{prefix}.images.{key}"] = {
"dtype": "video" if use_video else "image",
"shape": shape,
"names": ["height", "width", "channels"],
}
_validate_feature_names(features)
return features
@@ -158,11 +149,11 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
type = FeatureType.VISUAL
if len(shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
else:
names = ft["names"]
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
names = ft["names"]
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif key == OBS_ENV_STATE:
type = FeatureType.ENV
elif key.startswith(OBS_STR):

View File

@@ -114,6 +114,10 @@ _dynamixel_sdk_available = is_package_available("dynamixel-sdk", import_name="dy
_feetech_sdk_available = is_package_available("feetech-servo-sdk", import_name="scservo_sdk")
_reachy2_sdk_available = is_package_available("reachy2_sdk")
_can_available = is_package_available("python-can", "can")
_motorbridge_available = is_package_available("motorbridge")
_motorbridge_smart_servo_available = is_package_available(
"motorbridge-smart-servo", import_name="motorbridge_smart_servo"
)
_unitree_sdk_available = is_package_available("unitree-sdk2py", "unitree_sdk2py")
_pyrealsense2_available = is_package_available("pyrealsense2") or is_package_available(
"pyrealsense2-macosx", import_name="pyrealsense2"

View File

@@ -160,6 +160,25 @@ def has_method(cls: object, method_name: str) -> bool:
return hasattr(cls, method_name) and callable(getattr(cls, method_name))
def unwrap_scalar(value: Any) -> Any:
"""Unwrap a tensor / numpy scalar / single-element list into a Python scalar.
Tensors and numpy scalars expose ``.item()``; single-element lists are
unwrapped recursively. Anything else is returned unchanged. Centralized
here so the language renderer and processor steps share one definition.
Raises:
ValueError: If ``value`` is a list with zero or multiple elements.
"""
if hasattr(value, "item"):
return value.item()
if isinstance(value, list):
if len(value) != 1:
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
return unwrap_scalar(value[0])
return value
def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
"""
Return True if a given string can be converted to a numpy dtype.

View File

@@ -107,15 +107,8 @@ def log_rerun_data(
for i, vi in enumerate(arr):
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
else:
if arr.shape[-1] == 1:
img_entity = (
rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis).compress()
if compress_images
else rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis)
)
else:
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
rr.log(key, entity=img_entity)
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
rr.log(key, entity=img_entity, static=True)
if action:
for k, v in action.items():

View File

@@ -0,0 +1,168 @@
#!/usr/bin/env python
from pathlib import Path
from textwrap import dedent
import pytest
from lerobot.configs.recipe import MessageTurn, TrainingRecipe, load_recipe
def _minimal_message_turn(content: str = "${task}") -> MessageTurn:
return MessageTurn(role="user", content=content, stream="high_level")
def _minimal_target_turn() -> MessageTurn:
return MessageTurn(role="assistant", content="ok", stream="high_level", target=True)
# ── Message-recipe validation ────────────────────────────────────────
def test_message_recipe_validates_unknown_binding():
with pytest.raises(ValueError, match="unknown binding"):
TrainingRecipe(
messages=[
MessageTurn(role="user", content="${missing}", stream="high_level"),
_minimal_target_turn(),
]
)
def test_message_turn_requires_a_stream():
"""Every turn must declare a stream — None is rejected at construction.
Previously this only failed at render time (``_validate_rendered``);
catching it here means a malformed recipe YAML errors at load instead
of at the first training sample.
"""
with pytest.raises(ValueError, match="missing a stream"):
MessageTurn(role="user", content="${task}")
def test_message_recipe_requires_at_least_one_target():
with pytest.raises(ValueError, match="target"):
TrainingRecipe(
messages=[
_minimal_message_turn(),
MessageTurn(role="assistant", content="no target", stream="high_level"),
]
)
def test_recipe_rejects_both_messages_and_blend():
with pytest.raises(ValueError, match="only one"):
TrainingRecipe(
messages=[_minimal_message_turn(), _minimal_target_turn()],
blend={"a": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])},
)
def test_recipe_rejects_neither_messages_nor_blend():
with pytest.raises(ValueError, match="must set one"):
TrainingRecipe()
# ── Blend validation ─────────────────────────────────────────────────
def test_blend_must_be_non_empty():
with pytest.raises(ValueError, match="at least one component"):
TrainingRecipe(blend={})
def test_blend_component_must_define_weight():
with pytest.raises(ValueError, match="weight"):
TrainingRecipe(blend={"a": TrainingRecipe(messages=[_minimal_target_turn()])})
def test_blend_component_weight_must_be_positive():
with pytest.raises(ValueError, match="positive weight"):
TrainingRecipe(blend={"a": TrainingRecipe(weight=0.0, messages=[_minimal_target_turn()])})
def test_blend_component_must_define_messages():
# A bare TrainingRecipe(weight=1.0) would itself raise; build it without
# going through __post_init__ to exercise the blend-level validator.
bad = TrainingRecipe.__new__(TrainingRecipe)
bad.messages = None
bad.bindings = None
bad.blend = None
bad.weight = 1.0
with pytest.raises(ValueError, match="must define messages"):
TrainingRecipe(blend={"a": bad})
def test_blend_components_cannot_themselves_define_a_blend():
inner = TrainingRecipe(blend={"x": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])})
# Force-bypass the inner component's normal validation so the test
# exercises the outer blend's "no nested blends" rule directly.
nested = TrainingRecipe.__new__(TrainingRecipe)
nested.messages = None
nested.bindings = None
nested.blend = inner.blend
nested.weight = 1.0
with pytest.raises(ValueError, match="cannot itself define a blend"):
TrainingRecipe(blend={"outer": nested})
# ── from_dict / from_yaml round-trips ────────────────────────────────
def test_from_dict_with_nested_blend():
recipe = TrainingRecipe.from_dict(
{
"blend": {
"a": {
"weight": 1.0,
"messages": [
{"role": "user", "content": "${task}", "stream": "high_level"},
{"role": "assistant", "content": "a", "stream": "high_level", "target": True},
],
},
"b": {
"weight": 2.0,
"messages": [
{"role": "user", "content": "${task}", "stream": "high_level"},
{"role": "assistant", "content": "b", "stream": "high_level", "target": True},
],
},
}
}
)
assert recipe.blend is not None
assert set(recipe.blend) == {"a", "b"}
assert recipe.blend["b"].weight == 2.0
# Inner messages were promoted to MessageTurn instances.
assert isinstance(recipe.blend["a"].messages[0], MessageTurn)
def test_from_yaml_round_trips_through_load_recipe(tmp_path: Path):
yaml_text = dedent(
"""
bindings:
custom: "active_at(t, style=subtask)"
messages:
- {role: user, content: "${task}: ${custom}", stream: high_level}
- {role: assistant, content: "ok", stream: high_level, target: true}
"""
).strip()
path = tmp_path / "recipe.yaml"
path.write_text(yaml_text)
via_classmethod = TrainingRecipe.from_yaml(path)
via_helper = load_recipe(path)
assert via_classmethod.bindings == {"custom": "active_at(t, style=subtask)"}
assert via_classmethod.messages[1].target is True
# ``load_recipe`` is just a wrapper, but assert the two paths agree
# on the structural result so a future divergence is caught here.
assert via_helper.bindings == via_classmethod.bindings
assert len(via_helper.messages) == len(via_classmethod.messages)
def test_from_yaml_rejects_non_mapping(tmp_path: Path):
path = tmp_path / "bad.yaml"
path.write_text("- just\n- a\n- list\n")
with pytest.raises(ValueError, match="mapping at the top level"):
TrainingRecipe.from_yaml(path)

View File

@@ -59,13 +59,11 @@ def _make_dummy_stats(features: dict) -> dict:
stats = {}
for key, ft in features.items():
if ft["dtype"] in ("image", "video"):
channels = ft["shape"][-1]
stat_shape = (channels, 1, 1)
stats[key] = {
"max": np.ones(stat_shape, dtype=np.float32),
"mean": np.full(stat_shape, 0.5, dtype=np.float32),
"min": np.zeros(stat_shape, dtype=np.float32),
"std": np.full(stat_shape, 0.25, dtype=np.float32),
"max": np.ones((3, 1, 1), dtype=np.float32),
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32),
"min": np.zeros((3, 1, 1), dtype=np.float32),
"std": np.full((3, 1, 1), 0.25, dtype=np.float32),
"count": np.array([5]),
}
elif ft["dtype"] in ("float32", "float64", "int64"):
@@ -144,45 +142,6 @@ def test_create_without_videos_has_no_video_path(tmp_path):
assert meta.video_keys == []
@pytest.mark.parametrize(
("marker_field", "marker_key"),
[
("info", "is_depth_map"),
("info", "video.is_depth_map"),
("video_info", "video.is_depth_map"),
],
ids=["info.is_depth_map", "info.video.is_depth_map_legacy", "video_info.video.is_depth_map_legacy"],
)
def test_depth_keys_property_filters_by_marker(tmp_path, marker_field, marker_key):
"""``depth_keys`` recognises the canonical and the two legacy marker variants."""
depth_feature = {
"dtype": "video",
"shape": (64, 96, 1),
"names": ["height", "width", "channels"],
marker_field: {marker_key: True},
}
features = {
**VIDEO_FEATURES,
"observation.images.laptop_depth": depth_feature,
}
meta = LeRobotDatasetMetadata.create(
repo_id="test/depth_keys",
fps=DEFAULT_FPS,
features=features,
root=tmp_path / f"depth_keys_{marker_field}_{marker_key.replace('.', '_')}",
)
assert set(meta.video_keys) == {"observation.images.laptop", "observation.images.laptop_depth"}
assert meta.depth_keys == ["observation.images.laptop_depth"]
def test_depth_keys_empty_when_no_marker(tmp_path):
meta = LeRobotDatasetMetadata.create(
repo_id="test/no_depth", fps=DEFAULT_FPS, features=VIDEO_FEATURES, root=tmp_path / "no_depth"
)
assert meta.depth_keys == []
def test_create_raises_on_existing_directory(tmp_path):
"""create() raises if root directory already exists."""
root = tmp_path / "existing"
@@ -426,3 +385,140 @@ def test_finalize_flushes_buffered_metadata(tmp_path):
assert episodes_dir.exists()
parquet_files = list(episodes_dir.rglob("*.parquet"))
assert len(parquet_files) > 0
# ── Tools accessor ───────────────────────────────────────────────────
def test_tools_falls_back_to_default_when_info_has_no_tools_field(tmp_path):
"""meta.tools returns DEFAULT_TOOLS when info.json doesn't declare any."""
from lerobot.datasets.language import DEFAULT_TOOLS
root = tmp_path / "no_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/no_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
assert meta.tools == DEFAULT_TOOLS
# info.json on disk should NOT include a `tools` key for clean datasets
with open(root / INFO_PATH) as f:
info_on_disk = json.load(f)
assert "tools" not in info_on_disk
def test_tools_reads_declared_tools_from_info_json(tmp_path):
"""A `tools` list written into info.json survives load → meta.tools.
Regression test for the bug where ``DatasetInfo.from_dict`` silently
dropped the ``tools`` key (no matching dataclass field), so
``meta.tools`` always returned ``DEFAULT_TOOLS`` regardless of
what was on disk.
"""
from lerobot.datasets.io_utils import load_info
root = tmp_path / "with_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/with_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
custom_tool = {
"type": "function",
"function": {
"name": "record_observation",
"description": "Capture a still image.",
"parameters": {
"type": "object",
"properties": {"label": {"type": "string"}},
"required": ["label"],
},
},
}
info_path = root / INFO_PATH
with open(info_path) as f:
raw = json.load(f)
raw["tools"] = [custom_tool]
with open(info_path, "w") as f:
json.dump(raw, f)
# Reload info from disk and rebind it on the metadata object
meta.info = load_info(root)
assert meta.tools == [custom_tool]
def test_tools_round_trip_through_dataset_info(tmp_path):
"""A `tools` list survives DatasetInfo.from_dict / to_dict."""
from lerobot.datasets.utils import DatasetInfo
raw = {
"codebase_version": "v3.1",
"fps": 30,
"features": SIMPLE_FEATURES,
"tools": [{"type": "function", "function": {"name": "say"}}],
}
info = DatasetInfo.from_dict(raw)
assert info.tools == raw["tools"]
assert info.to_dict()["tools"] == raw["tools"]
def test_tools_setter_persists_to_info_json_and_reloads(tmp_path):
"""Assigning meta.tools writes info.json and reloads meta.info."""
from lerobot.datasets.io_utils import load_info
root = tmp_path / "set_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/set_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
custom_tool = {
"type": "function",
"function": {
"name": "record_observation",
"description": "Capture a still image.",
"parameters": {
"type": "object",
"properties": {"label": {"type": "string"}},
"required": ["label"],
},
},
}
meta.tools = [custom_tool]
# In-memory metadata reflects the new catalog ...
assert meta.tools == [custom_tool]
assert meta.info.tools == [custom_tool]
# ... and a fresh read from disk agrees.
assert load_info(root).tools == [custom_tool]
def test_tools_setter_clears_key_when_set_to_none(tmp_path):
"""Setting meta.tools back to None drops the key and restores the default."""
from lerobot.datasets.language import DEFAULT_TOOLS
root = tmp_path / "clear_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/clear_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
meta.tools = [{"type": "function", "function": {"name": "say"}}]
meta.tools = None
assert meta.tools == DEFAULT_TOOLS
with open(root / INFO_PATH) as f:
info_on_disk = json.load(f)
assert "tools" not in info_on_disk

View File

@@ -23,6 +23,7 @@ import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.configs import VideoEncoderConfig
from lerobot.datasets.dataset_tools import (
add_features,
@@ -31,9 +32,12 @@ from lerobot.datasets.dataset_tools import (
merge_datasets,
modify_features,
modify_tasks,
reencode_dataset,
remove_feature,
split_dataset,
)
from lerobot.datasets.io_utils import load_info
from tests.datasets.test_video_encoding import _add_frames, require_h264, require_libsvtav1
@pytest.fixture
@@ -1326,3 +1330,41 @@ def test_convert_image_to_video_dataset_subset_episodes(tmp_path):
if output_dir.exists():
shutil.rmtree(output_dir)
# ─── reencode_dataset ─────────────────────────────────────────────────
@require_libsvtav1
@require_h264
def test_reencode_dataset_multi_key_multiprocessing(
tmp_path, empty_lerobot_dataset_factory, features_factory
):
"""Re-encode a two-camera dataset with num_workers=2 and verify metadata refresh."""
features = features_factory(use_videos=True)
initial_cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
dataset = empty_lerobot_dataset_factory(
root=tmp_path / "ds",
features=features,
use_videos=True,
camera_encoder=initial_cfg,
)
_add_frames(dataset, num_frames=4)
dataset.save_episode()
_add_frames(dataset, num_frames=4)
dataset.save_episode()
dataset.finalize()
assert len(dataset.meta.video_keys) == 2
target_cfg = VideoEncoderConfig(vcodec="h264", g=6, crf=23, pix_fmt="yuv420p")
result = reencode_dataset(dataset, camera_encoder=target_cfg, num_workers=2)
assert result is dataset
persisted_info = load_info(dataset.root)
for vk in dataset.meta.video_keys:
persisted_encoder = VideoEncoderConfig.from_video_info(persisted_info.features[vk].get("info", {}))
assert persisted_encoder == target_cfg

View File

@@ -53,8 +53,8 @@ def _make_frame(features: dict, task: str = "Dummy task") -> dict:
# ── Existing encode_video_worker tests ───────────────────────────────
def test_encode_video_worker_forwards_video_encoder(tmp_path):
"""_encode_video_worker forwards video_encoder to encode_video_frames."""
def test_encode_video_worker_forwards_camera_encoder(tmp_path):
"""_encode_video_worker forwards camera_encoder to encode_video_frames."""
video_key = "observation.images.laptop"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
img_dir = tmp_path / Path(fpath).parent
@@ -74,16 +74,16 @@ def test_encode_video_worker_forwards_video_encoder(tmp_path):
0,
tmp_path,
fps=30,
video_encoder=VideoEncoderConfig(vcodec="h264", preset=None),
camera_encoder=VideoEncoderConfig(vcodec="h264", preset=None),
encoder_threads=4,
)
assert captured_kwargs["video_encoder"].vcodec == "h264"
assert captured_kwargs["camera_encoder"].vcodec == "h264"
assert captured_kwargs["encoder_threads"] == 4
def test_encode_video_worker_default_video_encoder(tmp_path):
"""_encode_video_worker passes None video_encoder which encode_video_frames defaults."""
def test_encode_video_worker_default_camera_encoder(tmp_path):
"""_encode_video_worker passes None camera_encoder which encode_video_frames defaults."""
video_key = "observation.images.laptop"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
img_dir = tmp_path / Path(fpath).parent
@@ -100,7 +100,7 @@ def test_encode_video_worker_default_video_encoder(tmp_path):
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
_encode_video_worker(video_key, 0, tmp_path, fps=30)
assert captured_kwargs["video_encoder"] is None
assert captured_kwargs["camera_encoder"] is None
assert captured_kwargs["encoder_threads"] is None

View File

@@ -24,6 +24,7 @@ import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import datasets
from huggingface_hub import HfApi
from PIL import Image
from safetensors.torch import load_file
@@ -360,6 +361,41 @@ def test_add_frame_image_pil(image_dataset):
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@pytest.mark.parametrize(
"dtype,np_dtype,values,assert_fn",
[
("float32", np.float32, [1.0, 2.0], np.testing.assert_allclose),
("int64", np.int64, [1, 2], np.testing.assert_array_equal),
("bool", np.bool_, [True, False], np.testing.assert_array_equal),
],
ids=["float32", "int64", "bool"],
)
def test_save_episode_shape_1_scalar_is_scalarized_before_hf_encoding(
tmp_path, empty_lerobot_dataset_factory, monkeypatch, dtype, np_dtype, values, assert_fn
):
features = {"state": {"dtype": dtype, "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": np.array([values[0]], dtype=np_dtype), "task": "Dummy task"})
dataset.add_frame({"state": np.array([values[1]], dtype=np_dtype), "task": "Dummy task"})
captured = {}
original_from_dict = datasets.Dataset.from_dict
def _from_dict_spy(cls, mapping, *args, **kwargs):
captured["state"] = mapping["state"]
return original_from_dict(mapping, *args, **kwargs)
monkeypatch.setattr(datasets.Dataset, "from_dict", classmethod(_from_dict_spy))
dataset.save_episode()
dataset.finalize()
assert "state" in captured
assert isinstance(captured["state"], np.ndarray)
assert captured["state"].shape == (2,)
assert_fn(captured["state"], np.array(values, dtype=np_dtype))
def test_set_image_transforms_applies_transparently(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
@@ -1480,15 +1516,10 @@ def test_valid_video_codecs_constant():
assert "h264" in VALID_VIDEO_CODECS
assert "hevc" in VALID_VIDEO_CODECS
assert "libsvtav1" in VALID_VIDEO_CODECS
assert "ffv1" in VALID_VIDEO_CODECS
assert "auto" in VALID_VIDEO_CODECS
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
assert "h264_nvenc" in VALID_VIDEO_CODECS
assert "h264_vaapi" in VALID_VIDEO_CODECS
assert "h264_qsv" in VALID_VIDEO_CODECS
assert "hevc_videotoolbox" in VALID_VIDEO_CODECS
assert "hevc_nvenc" in VALID_VIDEO_CODECS
assert len(VALID_VIDEO_CODECS) == 11
assert len(VALID_VIDEO_CODECS) == 10
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):

View File

@@ -1,307 +0,0 @@
"""Tests for the depth-integration feature.
Covers quantization/dequantization round-trips (depth_utils), image writer
depth support (image_writer), hardware→dataset feature routing
(feature_utils), video info helpers (video_utils / configs.video), and
feature-to-file-format routing through the dataset writer.
Depth metadata detection on ``LeRobotDatasetMetadata.depth_keys`` (canonical
and legacy marker variants) lives in ``test_dataset_metadata.py``.
"""
from pathlib import Path
import numpy as np
import PIL.Image
import pytest
import torch
pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
import av
from lerobot.configs import DepthEncoderConfig
from lerobot.configs.video import DEPTH_QMAX, VALID_VIDEO_CODECS
from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth
from lerobot.datasets.image_writer import (
image_array_to_pil_image,
save_kwargs_for_path,
write_image,
)
from lerobot.datasets.pyav_utils import get_pix_fmt_channels
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
DUMMY_DEPTH_CAMERA_FEATURES,
DUMMY_MOTOR_FEATURES,
DUMMY_REPO_ID,
)
H, W = 48, 64
DEPTH_MIN = 0.01
DEPTH_MAX = 10.0
# ── 1. Quantize / Dequantize round-trips ────────────────────────────
class TestQuantizeDequantize:
"""Core numerical tests for depth_utils.quantize_depth / dequantize_depth."""
def _make_depth_metres(self) -> np.ndarray:
"""Linearly-spaced float32 depth in metres covering the default range."""
return np.linspace(DEPTH_MIN, DEPTH_MAX, H * W, dtype=np.float32).reshape(H, W)
def test_roundtrip_linear_metres(self):
depth = self._make_depth_metres()
quantized = quantize_depth(depth, use_log=False, video_backend=None)
recovered = dequantize_depth(quantized, use_log=False, output_unit="m")
assert recovered.shape == (H, W, 1), f"Expected (H,W,1), got {recovered.shape}"
assert recovered.dtype == np.float32
tol = (DEPTH_MAX - DEPTH_MIN) / DEPTH_QMAX
np.testing.assert_allclose(recovered[..., 0], depth, atol=tol + 1e-6)
def test_roundtrip_log_metres(self):
depth = self._make_depth_metres()
quantized = quantize_depth(depth, use_log=True, video_backend=None)
recovered = dequantize_depth(quantized, use_log=True, output_unit="m")
assert recovered.shape == (H, W, 1)
near = depth < 1.0
far = depth > 8.0
err_near = np.abs(recovered[..., 0][near] - depth[near])
err_far = np.abs(recovered[..., 0][far] - depth[far])
assert err_near.mean() < err_far.mean(), "Log quant should be more precise at close range"
def test_roundtrip_mm_uint16_input(self):
depth_mm = np.linspace(10, 10000, H * W, dtype=np.float64).reshape(H, W).astype(np.uint16)
quantized = quantize_depth(depth_mm, use_log=False, video_backend=None, input_unit="mm")
recovered = dequantize_depth(quantized, use_log=False, output_unit="mm")
assert recovered.dtype == np.uint16
tol_mm = (DEPTH_MAX - DEPTH_MIN) * 1000.0 / DEPTH_QMAX
np.testing.assert_allclose(
recovered[..., 0].astype(np.float64), depth_mm.astype(np.float64), atol=tol_mm + 1.0
)
def test_quantize_clamps_out_of_range(self):
depth = np.array([[0.001, 99.0]], dtype=np.float32)
quantized = quantize_depth(depth, use_log=False, video_backend=None)
assert quantized[0, 0] == 0
assert quantized[0, 1] == DEPTH_QMAX
def test_quantize_accepts_torch_tensor(self):
t = torch.rand(H, W, dtype=torch.float32) * (DEPTH_MAX - DEPTH_MIN) + DEPTH_MIN
result = quantize_depth(t, video_backend=None)
assert isinstance(result, np.ndarray)
assert result.dtype == np.uint16
def test_quantize_squeezes_channel_dim(self):
depth = self._make_depth_metres()
for shape in [(H, W, 1), (1, H, W)]:
reshaped = depth.reshape(shape)
quantized = quantize_depth(reshaped, video_backend=None)
assert quantized.ndim == 2, f"Input shape {shape} should be squeezed to 2D"
def test_quantize_returns_pyav_frame(self):
depth = self._make_depth_metres()
result = quantize_depth(depth, video_backend="pyav")
assert isinstance(result, av.VideoFrame)
def test_dequantize_output_tensor(self):
quantized = np.full((H, W), DEPTH_QMAX // 2, dtype=np.uint16)
result = dequantize_depth(quantized, output_unit="m", output_tensor=True)
assert isinstance(result, torch.Tensor)
assert result.shape == (H, W, 1)
def test_invalid_log_params_raises(self):
depth = np.ones((4, 4), dtype=np.float32)
with pytest.raises(ValueError, match="depth_min \\+ shift must be positive"):
quantize_depth(depth, depth_min=1.0, shift=-2.0, use_log=True, video_backend=None)
# ── 2. Image writer depth support ───────────────────────────────────
class TestImageWriterDepth:
"""image_array_to_pil_image and write_image for single-channel depth maps."""
def test_pil_uint16_grayscale(self):
arr = np.arange(H * W, dtype=np.uint16).reshape(H, W)
img = image_array_to_pil_image(arr)
assert isinstance(img, PIL.Image.Image)
assert img.mode == "I;16"
assert img.size == (W, H)
def test_pil_float32_grayscale(self):
arr = np.random.rand(H, W).astype(np.float32)
img = image_array_to_pil_image(arr)
assert img.mode == "F"
def test_pil_squeeze_hwc1_and_1hw(self):
arr_uint16 = np.zeros((H, W), dtype=np.uint16)
for input_arr in [arr_uint16.reshape(H, W, 1), arr_uint16.reshape(1, H, W)]:
img = image_array_to_pil_image(input_arr)
assert img.size == (W, H)
def test_save_kwargs_png_vs_tiff(self):
png_kw = save_kwargs_for_path(Path("frame.png"), compress_level=5)
assert png_kw == {"compress_level": 5}
tiff_kw = save_kwargs_for_path(Path("frame.tiff"), compress_level=5)
assert tiff_kw == {"compression": "raw"}
assert save_kwargs_for_path(Path("frame.jpg"), compress_level=5) == {}
def test_write_image_tiff_roundtrip(self, tmp_path):
arr = np.arange(H * W, dtype=np.uint16).reshape(H, W)
fpath = tmp_path / "depth.tiff"
write_image(arr, fpath)
assert fpath.exists()
with PIL.Image.open(fpath) as loaded:
recovered = np.array(loaded)
np.testing.assert_array_equal(recovered, arr)
# ── 3. Feature routing ──────────────────────────────────────────────
class TestHwToDatasetFeaturesDepth:
"""hw_to_dataset_features marks single-channel cameras as depth."""
def test_single_channel_cam_marked_depth(self):
from lerobot.utils.feature_utils import hw_to_dataset_features
features = hw_to_dataset_features({"cam": (480, 640, 1)}, prefix="observation")
ft = features["observation.images.cam"]
assert ft["info"]["is_depth_map"] is True
def test_three_channel_cam_not_depth(self):
from lerobot.utils.feature_utils import hw_to_dataset_features
features = hw_to_dataset_features({"cam": (480, 640, 3)}, prefix="observation")
ft = features["observation.images.cam"]
assert ft["info"]["is_depth_map"] is False
def test_invalid_channel_count_raises(self):
from lerobot.utils.feature_utils import hw_to_dataset_features
with pytest.raises(ValueError, match="Expected a 3-tuple"):
hw_to_dataset_features({"cam": (480, 640, 2)}, prefix="observation")
# ── 4. Video info depth flag ────────────────────────────────────────
class TestVideoInfoDepthFlag:
"""Misc depth-related constants and helpers in video_utils / configs."""
def test_get_pix_fmt_channels_gray(self):
assert get_pix_fmt_channels("gray12le") == 1
assert get_pix_fmt_channels("gray8") == 1
def test_ffv1_in_valid_codecs(self):
assert "ffv1" in VALID_VIDEO_CODECS
# ── 5. Feature-to-file-format routing ───────────────────────────────
def _build_mixed_features(dtype: str) -> dict:
"""Build a feature dict with one RGB camera and one depth camera.
Uses shapes from ``DUMMY_CAMERA_FEATURES`` and ``DUMMY_DEPTH_CAMERA_FEATURES``
defined in ``tests.fixtures.constants``.
"""
rgb_cam = next(iter(DUMMY_CAMERA_FEATURES.values()))
depth_cam = next(iter(DUMMY_DEPTH_CAMERA_FEATURES.values()))
return {
"observation.images.rgb": {"dtype": dtype, **rgb_cam},
"observation.images.depth": {"dtype": dtype, **depth_cam},
**{k: {"dtype": v["dtype"], **v} for k, v in DUMMY_MOTOR_FEATURES.items()},
}
def _make_mixed_frame(features: dict) -> dict:
"""Build a valid frame dict matching the given feature schema."""
frame: dict = {"task": "test task"}
for key, ft in features.items():
shape = ft["shape"]
if ft["dtype"] in ("image", "video"):
channels = shape[-1]
if channels == 1:
frame[key] = np.random.randint(0, 4095, shape, dtype=np.uint16)
else:
frame[key] = np.random.randint(0, 255, shape, dtype=np.uint8)
else:
frame[key] = np.random.randn(*shape).astype(ft["dtype"])
return frame
class TestFeatureFileRouting:
"""Verify that depth vs RGB features are routed to the correct file format."""
NUM_FRAMES = 5
def test_no_video_depth_tiff_rgb_png(self, tmp_path):
"""Without video encoding: depth -> .tiff, RGB -> .png."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
features = _build_mixed_features(dtype="image")
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID,
fps=DEFAULT_FPS,
features=features,
root=tmp_path / "ds",
use_videos=False,
)
for _ in range(self.NUM_FRAMES):
dataset.add_frame(_make_mixed_frame(features))
buf = dataset.writer.episode_buffer
depth_paths = [Path(p) for p in buf["observation.images.depth"]]
rgb_paths = [Path(p) for p in buf["observation.images.rgb"]]
assert all(p.suffix == ".tiff" for p in depth_paths), "Depth frames should be .tiff"
assert all(p.suffix == ".png" for p in rgb_paths), "RGB frames should be .png"
assert all(p.exists() for p in depth_paths), "Depth TIFF files should exist on disk"
assert all(p.exists() for p in rgb_paths), "RGB PNG files should exist on disk"
dataset.save_episode()
dataset.finalize()
def test_video_depth_uses_depth_encoder(self, tmp_path):
"""With streaming video encoding: depth keys use DepthEncoderConfig, RGB keys do not."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
features = _build_mixed_features(dtype="video")
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID,
fps=DEFAULT_FPS,
features=features,
root=tmp_path / "ds",
use_videos=True,
streaming_encoding=True,
)
assert dataset.writer._streaming_encoder is not None
encoder = dataset.writer._streaming_encoder
for _ in range(self.NUM_FRAMES):
dataset.add_frame(_make_mixed_frame(features))
rgb_thread = encoder._threads["observation.images.rgb"]
depth_thread = encoder._threads["observation.images.depth"]
assert not isinstance(rgb_thread.video_encoder, DepthEncoderConfig)
assert isinstance(depth_thread.video_encoder, DepthEncoderConfig)
assert depth_thread.is_depth is True
assert rgb_thread.is_depth is False
dataset.save_episode()
dataset.finalize()

View File

@@ -94,7 +94,7 @@ def test_image_array_to_pil_image_pytorch_format(img_array_factory):
def test_image_array_to_pil_image_single_channel(img_array_factory):
img_array = img_array_factory(channels=1)
with pytest.raises(ValueError, match="Unsupported single-channel image dtype"):
with pytest.raises(NotImplementedError):
image_array_to_pil_image(img_array)

View File

@@ -0,0 +1,173 @@
#!/usr/bin/env python
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import numpy as np # noqa: E402
import pandas as pd # noqa: E402
import pyarrow as pa # noqa: E402
from lerobot.datasets import LeRobotDataset # noqa: E402
from lerobot.datasets.io_utils import write_info # noqa: E402
from lerobot.datasets.language import ( # noqa: E402
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
VIEW_DEPENDENT_STYLES,
column_for_style,
is_view_dependent_style,
language_events_arrow_type,
language_feature_info,
language_persistent_arrow_type,
validate_camera_field,
)
from lerobot.datasets.utils import DEFAULT_DATA_PATH # noqa: E402
def test_language_arrow_schema_has_expected_fields():
persistent_row_type = language_persistent_arrow_type().value_type
event_row_type = language_events_arrow_type().value_type
assert isinstance(persistent_row_type, pa.StructType)
assert persistent_row_type.names == [
"role",
"content",
"style",
"timestamp",
"camera",
"tool_calls",
]
assert isinstance(event_row_type, pa.StructType)
assert event_row_type.names == ["role", "content", "style", "camera", "tool_calls"]
# Persistent-row timestamps use float32, matching LeRobotDataset frame timestamps.
assert persistent_row_type.field("timestamp").type == pa.float32()
def test_validate_feature_language_warns_only_on_non_empty_value(caplog):
from lerobot.datasets.feature_utils import validate_feature_language
# None (the expected record-time value) is silent and non-fatal.
with caplog.at_level("WARNING"):
assert validate_feature_language("language_persistent", None) == ""
assert caplog.records == []
# A stray non-empty value is dropped later, so we warn rather than fail.
with caplog.at_level("WARNING"):
assert validate_feature_language("language_persistent", [{"role": "user"}]) == ""
assert any("language_persistent" in r.message for r in caplog.records)
def test_style_registry_routes_columns():
assert {"subtask", "plan", "memory", "motion", "task_aug"} == PERSISTENT_STYLES
assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES
assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY
assert column_for_style("subtask") == LANGUAGE_PERSISTENT
assert column_for_style("plan") == LANGUAGE_PERSISTENT
assert column_for_style("memory") == LANGUAGE_PERSISTENT
assert column_for_style("motion") == LANGUAGE_PERSISTENT
assert column_for_style("task_aug") == LANGUAGE_PERSISTENT
assert column_for_style("interjection") == LANGUAGE_EVENTS
assert column_for_style("vqa") == LANGUAGE_EVENTS
assert column_for_style("trace") == LANGUAGE_EVENTS
assert column_for_style(None) == LANGUAGE_EVENTS
def test_view_dependent_styles():
# motion lives in PERSISTENT_STYLES and is described in robot-frame
# (joint / Cartesian) terms, so it is NOT view-dependent. Only vqa
# (event) and trace (event, pixel-trajectory) carry a camera tag.
assert {"vqa", "trace"} == VIEW_DEPENDENT_STYLES
assert is_view_dependent_style("vqa")
assert is_view_dependent_style("trace")
assert not is_view_dependent_style("motion")
assert not is_view_dependent_style("subtask")
assert not is_view_dependent_style("plan")
assert not is_view_dependent_style("interjection")
assert not is_view_dependent_style(None)
def test_validate_camera_field_requires_camera_for_view_dependent_styles():
validate_camera_field("vqa", "observation.images.top")
validate_camera_field("trace", "observation.images.front")
with pytest.raises(ValueError, match="view-dependent"):
validate_camera_field("vqa", None)
with pytest.raises(ValueError, match="view-dependent"):
validate_camera_field("trace", "")
def test_validate_camera_field_rejects_camera_on_non_view_dependent_styles():
validate_camera_field("subtask", None)
validate_camera_field("plan", None)
validate_camera_field("memory", None)
validate_camera_field("motion", None)
validate_camera_field("interjection", None)
validate_camera_field(None, None)
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("subtask", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("motion", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("interjection", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field(None, "observation.images.top")
def test_unknown_style_rejected():
with pytest.raises(ValueError, match="Unknown language style"):
column_for_style("surprise")
def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot_dataset_factory):
root = tmp_path / "language_dataset"
dataset = empty_lerobot_dataset_factory(
root=root,
features={"state": {"dtype": "float32", "shape": (2,), "names": None}},
use_videos=False,
)
dataset.add_frame({"state": np.array([0.0, 1.0], dtype=np.float32), "task": "tidy"})
dataset.add_frame({"state": np.array([1.0, 2.0], dtype=np.float32), "task": "tidy"})
dataset.save_episode()
dataset.finalize()
persistent = [
{
"role": "assistant",
"content": "reach for the cup",
"style": "subtask",
"timestamp": 0.0,
"camera": None,
"tool_calls": None,
}
]
event = {
"role": "user",
"content": "what is visible?",
"style": "vqa",
"camera": "observation.images.top",
"tool_calls": None,
}
data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
df = pd.read_parquet(data_path)
df[LANGUAGE_PERSISTENT] = [persistent, persistent]
df[LANGUAGE_EVENTS] = [[event], []]
df.to_parquet(data_path)
info = dataset.meta.info
info["features"].update(language_feature_info())
write_info(info, root)
reloaded = LeRobotDataset(repo_id=dataset.repo_id, root=root)
first = reloaded[0]
second = reloaded[1]
assert first[LANGUAGE_PERSISTENT] == persistent
assert first[LANGUAGE_EVENTS] == [event]
assert second[LANGUAGE_PERSISTENT] == persistent
assert second[LANGUAGE_EVENTS] == []

View File

@@ -0,0 +1,417 @@
#!/usr/bin/env python
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
from lerobot.datasets.language_render import ( # noqa: E402
EMITTED_AT_TOLERANCE_S,
active_at,
emitted_at,
nth_next,
nth_prev,
render_sample,
)
def persistent_row(role, content, style, timestamp, tool_calls=None, camera=None):
return {
"role": role,
"content": content,
"style": style,
"timestamp": timestamp,
"camera": camera,
"tool_calls": tool_calls,
}
def event_row(role, content, style, tool_calls=None, camera=None):
return {
"role": role,
"content": content,
"style": style,
"camera": camera,
"tool_calls": tool_calls,
}
PERSISTENT = [
persistent_row("assistant", "plan 0", "plan", 0.0),
persistent_row("assistant", "memory 0", "memory", 0.0),
persistent_row("assistant", "subtask 0", "subtask", 0.0),
persistent_row("assistant", "memory 1", "memory", 1.0),
persistent_row("assistant", "subtask 1", "subtask", 1.0),
]
EVENTS_AT_1 = [
event_row("user", "what is visible?", "vqa", camera="observation.images.top"),
event_row("assistant", '{"count": 2}', "vqa", camera="observation.images.top"),
]
EVENTS_AT_2 = [
event_row("user", "skip wiping", "interjection"),
event_row(
"assistant",
None,
None,
[{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}],
),
]
# Same emission tick, two cameras: triggers per-camera disambiguation in
# resolvers, mirroring how Module 3 of the annotation pipeline writes one
# (vqa, user) + (vqa, assistant) pair per camera.
EVENTS_AT_3_TWO_CAMERAS = [
event_row("user", "how many cups (top)?", "vqa", camera="observation.images.top"),
event_row("assistant", '{"count": 3}', "vqa", camera="observation.images.top"),
event_row("user", "how many cups (wrist)?", "vqa", camera="observation.images.wrist"),
event_row("assistant", '{"count": 1}', "vqa", camera="observation.images.wrist"),
]
def test_resolver_temporal_semantics():
assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0"
assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1"
assert emitted_at(0.5, persistent=PERSISTENT, events=[], style="vqa", role="assistant") is None
assert (
emitted_at(1.0, persistent=PERSISTENT, events=EVENTS_AT_1, style="vqa", role="assistant")["content"]
== '{"count": 2}'
)
def test_persistent_relative_resolvers_reject_event_styles():
with pytest.raises(ValueError, match="event-only"):
active_at(1.0, persistent=PERSISTENT, style="vqa")
with pytest.raises(ValueError, match="event-only"):
nth_prev(1.0, persistent=PERSISTENT, style="interjection")
def test_nth_prev_and_next():
assert nth_prev(1.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 0"
assert nth_next(0.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 1"
def test_substitution_if_present_multimodal_and_tool_calls():
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="user",
content=[
{"type": "image", "feature": "observation.images.top"},
{"type": "text", "text": "${task}: ${interjection}"},
],
stream="high_level",
if_present="interjection",
),
MessageTurn(
role="assistant",
content="${plan}",
stream="high_level",
target=True,
tool_calls_from="speech",
),
],
bindings={"plan": "active_at(t, style=plan)"},
)
rendered = render_sample(
recipe=recipe,
persistent=PERSISTENT,
events=EVENTS_AT_2,
t=2.0,
sample_idx=0,
task="clean kitchen",
)
assert rendered["messages"][0]["content"][1]["text"] == "clean kitchen: skip wiping"
assert rendered["messages"][1]["content"] == "plan 0"
assert rendered["messages"][1]["tool_calls"][0]["function"]["name"] == "say"
assert rendered["message_streams"] == ["high_level", "high_level"]
assert rendered["target_message_indices"] == [1]
def test_exact_event_miss_returns_none_when_target_skips():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
]
)
assert (
render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=0) is None
)
def test_deterministic_blend_sampling():
recipe = TrainingRecipe(
blend={
"a": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="a", stream="high_level", target=True),
],
),
"b": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="b", stream="high_level", target=True),
],
),
}
)
first = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
)
second = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
)
assert first == second
def test_emitted_at_filters_vqa_by_camera():
top = emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
camera="observation.images.top",
)
wrist = emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
camera="observation.images.wrist",
)
assert top["content"] == '{"count": 3}'
assert wrist["content"] == '{"count": 1}'
def test_emitted_at_raises_on_ambiguous_per_camera_vqa():
with pytest.raises(ValueError, match="Ambiguous resolver"):
emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
)
def _vqa_subrecipe(camera: str) -> TrainingRecipe:
return TrainingRecipe(
weight=1.0,
bindings={
"vqa_query": f"emitted_at(t, style=vqa, role=user, camera={camera})",
"vqa": f"emitted_at(t, style=vqa, role=assistant, camera={camera})",
},
messages=[
MessageTurn(
role="user",
content=[{"type": "image", "feature": camera}, {"type": "text", "text": "${vqa_query}"}],
stream="high_level",
if_present="vqa_query",
),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
],
)
@pytest.mark.parametrize(
("camera", "expected_query", "expected_answer"),
[
("observation.images.top", "how many cups (top)?", '{"count": 3}'),
("observation.images.wrist", "how many cups (wrist)?", '{"count": 1}'),
],
)
def test_per_camera_blend_renders_both_views(camera, expected_query, expected_answer):
rendered = render_sample(
recipe=_vqa_subrecipe(camera),
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
t=3.0,
sample_idx=0,
)
assert rendered["messages"][0]["content"][0]["feature"] == camera
assert rendered["messages"][0]["content"][1]["text"] == expected_query
assert rendered["messages"][1]["content"] == expected_answer
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
rephrasings = [
persistent_row("user", "tidy the kitchen", "task_aug", 0.0),
persistent_row("user", "please clean up the kitchen", "task_aug", 0.0),
persistent_row("user", "kitchen needs tidying", "task_aug", 0.0),
persistent_row("user", "make the kitchen clean", "task_aug", 0.0),
]
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
# No explicit task override → resolver consults persistent rows.
seen: set[str] = set()
for sample_idx in range(64):
rendered = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=sample_idx,
dataset_ctx={"task": "canonical kitchen task"},
)
seen.add(rendered["messages"][0]["content"])
# Every rephrasing should be reachable across enough samples.
assert seen == {r["content"] for r in rephrasings}
# Same sample_idx → same pick (determinism).
a = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=42,
dataset_ctx={"task": "canonical"},
)
b = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=42,
dataset_ctx={"task": "canonical"},
)
assert a["messages"][0]["content"] == b["messages"][0]["content"]
def test_resolve_task_falls_back_to_canonical_without_rephrasings():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
rendered = render_sample(
recipe=recipe,
persistent=PERSISTENT, # no task_aug rows
events=[],
t=0.0,
sample_idx=0,
dataset_ctx={"task": "clean the kitchen"},
)
assert rendered["messages"][0]["content"] == "clean the kitchen"
def test_resolve_task_explicit_override_beats_rephrasings():
rephrasings = [
persistent_row("user", "rephrased one", "task_aug", 0.0),
persistent_row("user", "rephrased two", "task_aug", 0.0),
]
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
rendered = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=0,
task="explicit override wins",
dataset_ctx={"task": "canonical"},
)
assert rendered["messages"][0]["content"] == "explicit override wins"
def test_emitted_at_persistent_tolerates_small_timestamp_drift():
"""Persistent ``emitted_at`` should match within EMITTED_AT_TOLERANCE_S
so callers that derive ``t`` arithmetically (``frame_idx / fps``) still
line up with the parquet-stored timestamp.
"""
rows = [persistent_row("assistant", "memo", "memory", 1.0)]
# Half a tolerance window — bit-different float, comfortably inside
inside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S / 2, persistent=rows, events=[], style="memory")
assert inside is not None and inside["content"] == "memo"
# Just past the window — no match
outside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S * 2, persistent=rows, events=[], style="memory")
assert outside is None
def test_render_sample_rejects_non_dict_language_rows():
"""``_normalize_rows`` must surface malformed inputs as TypeError.
A pipeline that hands the renderer a non-dict (e.g. a stray string)
is a real upstream bug — silent skipping would let it propagate.
"""
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
with pytest.raises(TypeError, match="must be dictionaries"):
render_sample(
recipe=recipe,
persistent=["not a dict"],
events=[],
t=0.0,
sample_idx=0,
task="x",
)
def test_low_level_branch_renders_active_subtask():
low_level = TrainingRecipe(
blend={
"low": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(
role="user",
content="${task}\nPlan: ${plan}\nMemory: ${memory}",
stream="high_level",
),
MessageTurn(
role="assistant",
content="${subtask}",
stream="low_level",
target=True,
),
],
)
}
)
rendered = render_sample(
recipe=low_level,
persistent=PERSISTENT,
events=[],
t=0.5,
sample_idx=0,
task="clean kitchen",
)
assert rendered["messages"][-1] == {"role": "assistant", "content": "subtask 0"}
assert rendered["message_streams"][-1] == "low_level"
assert rendered["target_message_indices"] == [1]

View File

@@ -61,7 +61,9 @@ class TestCameraEncoderThread:
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=fps,
video_encoder=enc_cfg,
vcodec=enc_cfg.vcodec,
pix_fmt=enc_cfg.pix_fmt,
codec_options=enc_cfg.get_codec_options(as_strings=True),
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
@@ -110,7 +112,9 @@ class TestCameraEncoderThread:
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=fps,
video_encoder=enc_cfg,
vcodec=enc_cfg.vcodec,
pix_fmt=enc_cfg.pix_fmt,
codec_options=enc_cfg.get_codec_options(as_strings=True),
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
@@ -142,7 +146,9 @@ class TestCameraEncoderThread:
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=fps,
video_encoder=enc_cfg,
vcodec=enc_cfg.vcodec,
pix_fmt=enc_cfg.pix_fmt,
codec_options=enc_cfg.get_codec_options(as_strings=True),
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
@@ -385,8 +391,7 @@ class TestStreamingVideoEncoder:
# Verify codec options include thread tuning for libsvtav1 (lp=…)
thread = encoder._threads[f"{OBS_IMAGES}.cam"]
codec_opts = thread.video_encoder.get_codec_options(encoder_threads=thread.encoder_threads)
assert "svtav1-params" in codec_opts or "threads" in codec_opts
assert "svtav1-params" in thread.codec_options or "threads" in thread.codec_options
# Feed some frames and finish to ensure it works end-to-end
num_frames = 10

View File

@@ -1,193 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for subtask functionality in LeRobotDataset.
These tests verify that:
- Subtask information is correctly loaded from datasets that have subtask data
- The __getitem__ method correctly adds subtask strings to returned items
- Subtask handling gracefully handles missing data
"""
import pytest
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import pandas as pd # noqa: E402
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
class TestSubtaskDataset:
"""Tests for subtask handling in LeRobotDataset."""
@pytest.fixture
def subtask_dataset(self):
"""Load the test subtask dataset from the hub."""
# Use lerobot/pusht-subtask dataset with episode 1
return LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
def test_subtask_dataset_loads(self, subtask_dataset):
"""Test that the subtask dataset loads successfully."""
assert subtask_dataset is not None
assert len(subtask_dataset) > 0
def test_subtask_metadata_loaded(self, subtask_dataset):
"""Test that subtask metadata is loaded when present in dataset."""
# The dataset should have subtasks metadata loaded
assert subtask_dataset.meta.subtasks is not None
assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame)
def test_subtask_index_in_features(self, subtask_dataset):
"""Test that subtask_index is a feature when dataset has subtasks."""
assert "subtask_index" in subtask_dataset.features
def test_getitem_returns_subtask_string(self, subtask_dataset):
"""Test that __getitem__ correctly adds subtask string to returned item."""
item = subtask_dataset[0]
# Subtask should be present in the returned item
assert "subtask" in item
assert isinstance(item["subtask"], str)
assert len(item["subtask"]) > 0 # Should not be empty
def test_getitem_has_subtask_index(self, subtask_dataset):
"""Test that __getitem__ includes subtask_index."""
item = subtask_dataset[0]
assert "subtask_index" in item
assert isinstance(item["subtask_index"], torch.Tensor)
def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset):
"""Test that subtask_index correctly maps to a subtask in metadata."""
item = subtask_dataset[0]
subtask_idx = item["subtask_index"].item()
subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name
assert item["subtask"] == subtask_from_metadata
def test_all_items_have_subtask(self, subtask_dataset):
"""Test that all items in the dataset have subtask information."""
for i in range(min(len(subtask_dataset), 5)): # Check first 5 items
item = subtask_dataset[i]
assert "subtask" in item
assert isinstance(item["subtask"], str)
def test_task_and_subtask_coexist(self, subtask_dataset):
"""Test that both task and subtask are present in returned items."""
item = subtask_dataset[0]
# Both task and subtask should be present
assert "task" in item
assert "subtask" in item
assert isinstance(item["task"], str)
assert isinstance(item["subtask"], str)
class TestSubtaskDatasetMissing:
"""Tests for graceful handling when subtask data is missing."""
@pytest.fixture
def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory):
"""Create a dataset without subtask information."""
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features)
# Add some frames and save
for _ in range(5):
dataset.add_frame({"state": torch.randn(2), "task": "Test task"})
dataset.save_episode()
dataset.finalize()
# Reload the dataset
return LeRobotDataset(dataset.repo_id, root=dataset.root)
def test_no_subtask_in_features(self, dataset_without_subtasks):
"""Test that subtask_index is not in features when not provided."""
assert "subtask_index" not in dataset_without_subtasks.features
def test_getitem_without_subtask(self, dataset_without_subtasks):
"""Test that __getitem__ works when subtask is not present."""
item = dataset_without_subtasks[0]
# Item should still be retrievable
assert item is not None
assert "state" in item
assert "task" in item
# Subtask should NOT be present
assert "subtask" not in item
def test_subtasks_metadata_is_none(self, dataset_without_subtasks):
"""Test that subtasks metadata is None when not present."""
assert dataset_without_subtasks.meta.subtasks is None
class TestSubtaskEdgeCases:
"""Edge case tests for subtask handling."""
def test_subtask_with_multiple_episodes(self):
"""Test subtask handling with multiple episodes if available."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
# Check first and last items have valid subtasks
first_item = dataset[0]
last_item = dataset[len(dataset) - 1]
assert "subtask" in first_item
assert "subtask" in last_item
assert isinstance(first_item["subtask"], str)
assert isinstance(last_item["subtask"], str)
def test_subtask_index_consistency(self):
"""Test that same subtask_index returns same subtask string."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
if len(dataset) < 2:
pytest.skip("Dataset too small for this test")
# Collect subtask_index to subtask mappings
subtask_map = {}
for i in range(min(len(dataset), 10)):
item = dataset[i]
idx = item["subtask_index"].item()
subtask = item["subtask"]
if idx in subtask_map:
# Same index should always return same subtask
assert subtask_map[idx] == subtask, (
f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'"
)
else:
subtask_map[idx] = subtask

View File

@@ -0,0 +1,140 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for ``lerobot.datasets.video_utils.VideoDecoderCache``.
These cover the LRU bounding + file-handle release behaviour added to prevent
unbounded growth when iterating over datasets with many distinct video files
(observed: ~35 GB anon-rss per DataLoader worker on an 8 k-file dataset).
"""
import shutil
from pathlib import Path
import pytest
pytest.importorskip("torchcodec", reason="torchcodec is required (install lerobot[dataset])")
from lerobot.datasets.video_utils import VideoDecoderCache # noqa: E402
TEST_ARTIFACTS_DIR = Path(__file__).resolve().parent.parent / "artifacts" / "encoded_videos"
SRC_CLIP = TEST_ARTIFACTS_DIR / "clip_4frames.mp4"
def _make_distinct_clips(tmp_path: Path, n: int) -> list[Path]:
"""Copy the small reference mp4 to ``n`` distinct paths.
The cache keys on absolute path, so distinct paths force distinct cache entries
even though the file contents are identical.
"""
assert SRC_CLIP.exists(), f"missing test artifact {SRC_CLIP}"
paths = []
for i in range(n):
dst = tmp_path / f"clip_{i:04d}.mp4"
shutil.copyfile(SRC_CLIP, dst)
paths.append(dst)
return paths
class TestVideoDecoderCacheBounded:
def test_default_cache_is_bounded(self):
"""The default cache must have a finite ``max_size`` to bound RSS growth."""
cache = VideoDecoderCache()
assert cache.max_size is not None, "default cache must be bounded"
assert cache.max_size > 0
def test_size_capped_at_max_size(self, tmp_path):
"""``get_decoder`` for >``max_size`` distinct paths must NOT grow without bound."""
paths = _make_distinct_clips(tmp_path, n=5)
cache = VideoDecoderCache(max_size=2)
for p in paths:
cache.get_decoder(p)
assert cache.size() == 2
def test_evicts_least_recently_used(self, tmp_path):
"""Re-accessing an entry must promote it; the LRU entry is the one evicted."""
paths = _make_distinct_clips(tmp_path, n=3)
cache = VideoDecoderCache(max_size=2)
cache.get_decoder(paths[0])
cache.get_decoder(paths[1])
cache.get_decoder(paths[0]) # promote paths[0] to MRU; paths[1] is now LRU
cache.get_decoder(paths[2]) # should evict paths[1]
assert str(paths[0]) in cache # MRU stays
assert str(paths[1]) not in cache # LRU evicted
assert str(paths[2]) in cache # newest stays
def test_eviction_closes_file_handle(self, tmp_path):
"""Evicting an entry must close its fsspec file handle (otherwise we leak FDs)."""
paths = _make_distinct_clips(tmp_path, n=2)
cache = VideoDecoderCache(max_size=1)
cache.get_decoder(paths[0])
# Reach into the cache to capture the handle before it is evicted. This is
# the only assertion in the suite that touches a private attribute, and it
# is the most direct way to prove the file descriptor is actually released.
evicted_handle = cache._cache[str(paths[0])][1]
assert evicted_handle.closed is False
cache.get_decoder(paths[1]) # forces eviction of paths[0]
assert evicted_handle.closed is True
def test_clear_closes_all_file_handles(self, tmp_path):
"""``clear()`` must close every cached file handle."""
paths = _make_distinct_clips(tmp_path, n=3)
cache = VideoDecoderCache(max_size=10)
for p in paths:
cache.get_decoder(p)
handles = [entry[1] for entry in cache._cache.values()]
assert all(not h.closed for h in handles)
cache.clear()
assert cache.size() == 0
assert all(h.closed for h in handles)
def test_hit_does_not_reopen_or_evict(self, tmp_path):
"""A cache hit must return the same decoder instance without touching the cap."""
paths = _make_distinct_clips(tmp_path, n=1)
cache = VideoDecoderCache(max_size=2)
first = cache.get_decoder(paths[0])
second = cache.get_decoder(paths[0])
assert first is second
assert cache.size() == 1
def test_unbounded_when_max_size_none(self, tmp_path):
"""``max_size=None`` preserves the legacy unbounded behaviour."""
paths = _make_distinct_clips(tmp_path, n=4)
cache = VideoDecoderCache(max_size=None)
for p in paths:
cache.get_decoder(p)
assert cache.size() == 4
def test_env_var_overrides_default(self, tmp_path, monkeypatch):
"""``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var sets the default ``max_size``."""
monkeypatch.setenv("LEROBOT_VIDEO_DECODER_CACHE_SIZE", "3")
cache = VideoDecoderCache()
assert cache.max_size == 3
paths = _make_distinct_clips(tmp_path, n=5)
for p in paths:
cache.get_decoder(p)
assert cache.size() == 3

View File

@@ -26,7 +26,7 @@ pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
import av # noqa: E402
from lerobot.configs import VALID_VIDEO_CODECS, DepthEncoderConfig, VideoEncoderConfig
from lerobot.configs import VALID_VIDEO_CODECS, VideoEncoderConfig
from lerobot.datasets.image_writer import write_image
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pyav_utils import get_codec
@@ -35,6 +35,7 @@ from lerobot.datasets.video_utils import (
concatenate_video_files,
encode_video_frames,
get_video_info,
reencode_video,
)
from tests.fixtures.constants import DUMMY_VIDEO_INFO
@@ -338,7 +339,7 @@ def _encode_video(
) -> Path:
imgs_dir = path.parent / f"imgs_{path.stem}"
_write_frames(imgs_dir, num_frames=num_frames)
encode_video_frames(imgs_dir, path, fps=fps, video_encoder=cfg, overwrite=True)
encode_video_frames(imgs_dir, path, fps=fps, camera_encoder=cfg, overwrite=True)
return path
@@ -347,16 +348,22 @@ def _read_feature_info(dataset: LeRobotDataset) -> dict:
return info["features"][VIDEO_KEY]["info"]
def _add_frames(dataset: LeRobotDataset, num_frames: int) -> None:
shape = dataset.meta.features[VIDEO_KEY]["shape"]
def _add_frames(dataset: LeRobotDataset, num_frames: int, video_keys: list[str] | None = None) -> None:
from lerobot.utils.constants import DEFAULT_FEATURES
if video_keys is None:
video_keys = dataset.meta.video_keys
for _ in range(num_frames):
dataset.add_frame(
{
VIDEO_KEY: np.random.randint(0, 256, shape, dtype=np.uint8),
"action": np.zeros(2, dtype=np.float32),
"task": "test",
}
)
frame: dict = {"task": "test"}
for key, ft in dataset.meta.features.items():
if key in DEFAULT_FEATURES:
continue
shape = ft["shape"]
if key in video_keys:
frame[key] = np.random.randint(0, 256, shape, dtype=np.uint8)
else:
frame[key] = np.zeros(shape, dtype=np.float32)
dataset.add_frame(frame)
class TestGetVideoInfo:
@@ -368,7 +375,7 @@ class TestGetVideoInfo:
assert info["video.pix_fmt"] == "yuv420p"
assert info["video.fps"] == 30
assert info["video.channels"] == 3
assert info["is_depth_map"] is False
assert info["video.is_depth_map"] is False
assert info["has_audio"] is False
assert "video.g" not in info
assert "video.crf" not in info
@@ -378,7 +385,7 @@ class TestGetVideoInfo:
def test_merges_encoder_config_as_video_prefixed_entries(self):
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=cfg)
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", camera_encoder=cfg)
assert info["video.g"] == 2
assert info["video.crf"] == 30
@@ -391,16 +398,11 @@ class TestGetVideoInfo:
def test_stream_derived_keys_take_precedence_over_config(self):
cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p")
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=cfg)
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", camera_encoder=cfg)
assert info["video.codec"] # populated from stream, not from config's vcodec
assert info["video.pix_fmt"] == "yuv420p"
def test_depth_encoder_config_sets_is_depth_map_true(self):
"""A ``DepthEncoderConfig`` causes ``get_video_info`` to mark the stream as depth."""
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=DepthEncoderConfig())
assert info["is_depth_map"] is True
class TestEncodeVideoFrames:
@require_libsvtav1
@@ -459,7 +461,7 @@ class TestEncodeVideoFrames:
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=4, crf=25, preset=10)
video_path = _encode_video(tmp_path / "out.mp4", num_frames=4, fps=30, cfg=cfg)
info = get_video_info(video_path, video_encoder=cfg)
info = get_video_info(video_path, camera_encoder=cfg)
# Stream-derived
assert info["video.height"] == 64
@@ -468,7 +470,7 @@ class TestEncodeVideoFrames:
assert info["video.codec"] == "av1"
assert info["video.pix_fmt"] == "yuv420p"
assert info["video.fps"] == 30
assert info["is_depth_map"] is False
assert info["video.is_depth_map"] is False
assert info["has_audio"] is False
# Encoder config
assert info["video.g"] == 4
@@ -479,6 +481,30 @@ class TestEncodeVideoFrames:
assert info["video.extra_options"] == {}
class TestReencodeVideo:
@require_libsvtav1
@require_h264
def test_reencode_video(self, tmp_path):
src = TEST_ARTIFACTS_DIR / "clip_4frames.mp4"
out = tmp_path / "reencoded.mp4"
cfg = VideoEncoderConfig(vcodec="h264", g=6, crf=23, pix_fmt="yuv444p")
reencode_video(src, out, camera_encoder=cfg, overwrite=True)
assert out.exists()
with av.open(str(out)) as container:
n_frames = sum(1 for _ in container.decode(video=0))
assert n_frames == 4
info = get_video_info(out, camera_encoder=cfg)
assert info["video.codec"] == "h264"
assert info["video.pix_fmt"] == "yuv444p"
assert info["video.height"] == 64
assert info["video.width"] == 96
assert info["video.fps"] == 30
assert info["video.g"] == 6
assert info["video.crf"] == 23
class TestConcatenateVideoFiles:
def test_two_clips_frame_count(self, tmp_path):
"""Output frame count equals the sum of the two input frame counts."""

View File

@@ -39,23 +39,12 @@ DUMMY_VIDEO_INFO = {
"video.crf": 30,
"video.preset": 12,
"video.fast_decode": 0,
"is_depth_map": False,
"video.is_depth_map": False,
"has_audio": False,
}
DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": DUMMY_VIDEO_INFO},
"phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": DUMMY_VIDEO_INFO},
}
DUMMY_DEPTH_VIDEO_INFO = {
**DUMMY_VIDEO_INFO,
"is_depth_map": True,
}
DUMMY_DEPTH_CAMERA_FEATURES = {
"laptop_depth": {
"shape": (64, 96, 1),
"names": ["height", "width", "channels"],
"info": DUMMY_DEPTH_VIDEO_INFO,
},
}
DUMMY_CHW = (3, 96, 128)
DUMMY_HWC = (96, 128, 3)

View File

@@ -0,0 +1 @@
"""Lightweight vendored OpenPI PyTorch modules for PI0/PI05 parity tests."""

View File

@@ -0,0 +1,22 @@
from dataclasses import dataclass
@dataclass
class Config:
width: int
depth: int
mlp_dim: int
num_heads: int
num_kv_heads: int
head_dim: int
def get_config(variant: str) -> Config:
"""Return the Gemma shape config needed by the OpenPI PyTorch model."""
if variant == "dummy":
return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16)
if variant == "gemma_300m":
return Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256)
if variant == "gemma_2b":
return Config(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256)
raise ValueError(f"Unknown variant: {variant}")

View File

@@ -0,0 +1,300 @@
from typing import Literal
import torch
from torch import nn
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
layernorm_forward,
)
class PaliGemmaWithExpertModel(nn.Module):
def __init__(
self,
vlm_config,
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
):
if use_adarms is None:
use_adarms = [False, False]
super().__init__()
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
vlm_config_hf.image_token_index = 257152
vlm_config_hf.text_config.hidden_size = vlm_config.width
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
hidden_size=action_expert_config.width,
intermediate_size=action_expert_config.mlp_dim,
num_attention_heads=action_expert_config.num_heads,
num_hidden_layers=action_expert_config.depth,
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
if precision == "bfloat16":
self.to(dtype=torch.bfloat16)
elif precision == "float32":
self.to(dtype=torch.float32)
return
else:
raise ValueError(f"Invalid precision: {precision}")
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_keep_float32):
param.data = param.data.to(dtype=torch.float32)
def embed_image(self, image: torch.Tensor):
# Transformers 5.4 no longer divides PaliGemma image features by sqrt(hidden_size),
# so the upstream helper now matches OpenPI's patched PaliGemma image-scale semantics.
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-c916907e7e52ac85ee1a1527560eae4656cd6c76141ceb1fe3da61bd5f697d2a
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
def forward(
self,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: list[torch.FloatTensor] | None = None,
use_cache: bool | None = None,
adarms_cond: list[torch.Tensor] | None = None,
):
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
)
prefix_past_key_values = prefix_output.past_key_values
prefix_output = prefix_output.last_hidden_state
suffix_output = None
elif inputs_embeds[0] is None:
suffix_output = self.gemma_expert.model.forward(
inputs_embeds=inputs_embeds[1],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
)
suffix_output = suffix_output.last_hidden_state
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
hasattr(self.gemma_expert.model, "gradient_checkpointing")
and self.gemma_expert.model.gradient_checkpointing
and self.training
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Force enable gradient checkpointing if we're in training mode and the model supports it
if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
if not self.gemma_expert.model.gradient_checkpointing:
print("Forcing gradient checkpointing to be enabled for Gemma expert model")
self.gemma_expert.model.gradient_checkpointing = True
use_gradient_checkpointing = True
# Debug gradient checkpointing status
if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
print(f"Model training mode: {self.training}")
print(
f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
)
if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
print(
f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
)
self._debug_gc_printed = True
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
models = [self.paligemma.model.language_model, self.gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(
layer.input_layernorm, hidden_states, adarms_cond[i]
)
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# Concatenate and process attention
query_states = torch.cat(query_states, dim=2)
key_states = torch.cat(key_states, dim=2)
value_states = torch.cat(value_states, dim=2)
dummy_tensor = torch.zeros(
query_states.shape[0],
query_states.shape[2],
query_states.shape[-1],
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = self.paligemma.model.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
self.paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
attention_mask,
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = self.paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
after_first_residual = out_emb.clone()
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = _gated_residual(after_first_residual, out_emb, gate)
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
)
# Old code removed - now using compute_layer_complete function above
# final norm
# Define final norm computation function for gradient checkpointing
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
# Apply gradient checkpointing to final norm if enabled
if use_gradient_checkpointing:
outputs_embeds = torch.utils.checkpoint.checkpoint(
compute_final_norms,
inputs_embeds,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
)
else:
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
prefix_output = outputs_embeds[0]
suffix_output = outputs_embeds[1]
prefix_past_key_values = None
return [prefix_output, suffix_output], prefix_past_key_values

View File

@@ -0,0 +1,79 @@
import torch
import torch.nn.functional as F # noqa: N812
def resize_with_pad_torch(
images: torch.Tensor,
height: int,
width: int,
mode: str = "bilinear",
) -> torch.Tensor:
"""PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
by padding with black. If the image is float32, it must be in the range [-1, 1].
Args:
images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
height: Target height
width: Target width
mode: Interpolation mode ('bilinear', 'nearest', etc.)
Returns:
Resized and padded tensor with same shape format as input
"""
# Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
if images.shape[-1] <= 4: # Assume channels-last format
channels_last = True
# Convert to channels-first for torch operations
if images.dim() == 3:
images = images.unsqueeze(0) # Add batch dimension
images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w]
else:
channels_last = False
if images.dim() == 3:
images = images.unsqueeze(0) # Add batch dimension
batch_size, channels, cur_height, cur_width = images.shape
# Calculate resize ratio
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
# Resize
resized_images = F.interpolate(
images,
size=(resized_height, resized_width),
mode=mode,
align_corners=False if mode == "bilinear" else None,
)
# Handle dtype-specific clipping
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(-1.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
# Calculate padding
pad_h0, remainder_h = divmod(height - resized_height, 2)
pad_h1 = pad_h0 + remainder_h
pad_w0, remainder_w = divmod(width - resized_width, 2)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else -1.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
mode="constant",
value=constant_value,
)
# Convert back to original format if needed
if channels_last:
padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
if batch_size == 1 and images.shape[0] == 1:
padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
return padded_images

View File

@@ -0,0 +1,471 @@
import copy
import logging
import math
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
import tests.policies.pi0_pi05.openpi_pytorch.gemma as _gemma
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as _preprocessing
from tests.policies.pi0_pi05.openpi_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "cpu":
# CPU doesn't support bfloat16, use float32 instead
if target_dtype == torch.bfloat16:
return torch.float32
if target_dtype == torch.float64:
return torch.float64
return target_dtype
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
def sample_beta(alpha, beta, bsize, device):
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,))
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
setup several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
themselves and the last 3 tokens have a causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend all previous blocks and all tokens on the same block.
Args:
input_mask: bool[B, N] true if its part of the input, false if padding.
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
return att_2d_masks & pad_2d_masks
class PI0Pytorch(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.pi05 = config.pi05
paligemma_config = _gemma.get_config(config.paligemma_variant)
action_expert_config = _gemma.get_config(config.action_expert_variant)
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_config,
action_expert_config,
use_adarms=[False, True] if self.pi05 else [False, False],
precision=config.dtype,
)
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)
if self.pi05:
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
else:
self.state_proj = nn.Linear(config.action_dim, action_expert_config.width)
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
torch.set_float32_matmul_precision("high")
if config.pytorch_compile_mode is not None:
self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode)
# Initialize gradient checkpointing flag
self.gradient_checkpointing_enabled = False
# The upstream OpenPI module verifies a site-package Transformers patch here.
# This vendored test copy instead routes through LeRobot's local PiGemma compatibility layer.
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
def is_gradient_checkpointing_enabled(self):
"""Check if gradient checkpointing is enabled."""
return self.gradient_checkpointing_enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
return torch.utils.checkpoint.checkpoint(
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
)
return func(*args, **kwargs)
def _prepare_attention_masks_4d(self, att_2d_masks):
"""Helper method to prepare 4D attention masks for transformer."""
att_2d_masks_4d = att_2d_masks[:, None, :, :]
return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
def _preprocess_observation(self, observation, *, train=True):
"""Helper method to preprocess observation."""
observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
return (
list(observation.images.values()),
list(observation.image_masks.values()),
observation.tokenized_prompt,
observation.tokenized_prompt_mask,
observation.state,
)
def sample_noise(self, shape, device):
return torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer to prepare
for PaliGemma transformer processing.
"""
embs = []
pad_masks = []
att_masks = []
# Process images
for img, img_mask in zip(images, img_masks, strict=True):
def image_embed_func(img):
return self.paligemma_with_expert.embed_image(img)
img_emb = self._apply_checkpoint(image_embed_func, img)
bsize, num_img_embs = img_emb.shape[:2]
embs.append(img_emb)
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
# Create attention masks so that image tokens attend to each other
att_masks += [0] * num_img_embs
# Process language tokens
def lang_embed_func(lang_tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
# Transformers > 5.4 scales Gemma token embeddings inside embed_tokens, matching
# OpenPI's former explicit sqrt(hidden_size) multiply without applying it twice.
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834
return lang_emb
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
embs.append(lang_emb)
pad_masks.append(lang_masks)
# full attention between image and language inputs
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
# Get batch size from the first dimension of the concatenated tensors
bsize = pad_masks.shape[0]
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def embed_suffix(self, state, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = []
pad_masks = []
att_masks = []
if not self.pi05:
if self.state_proj.weight.dtype == torch.float32:
state = state.to(torch.float32)
# Embed state
def state_proj_func(state):
return self.state_proj(state)
state_emb = self._apply_checkpoint(state_proj_func, state)
embs.append(state_emb[:, None, :])
bsize = state_emb.shape[0]
device = state_emb.device
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
pad_masks.append(state_mask)
# Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1]
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep,
self.action_in_proj.out_features,
min_period=4e-3,
max_period=4.0,
device=timestep.device,
)
time_emb = time_emb.type(dtype=timestep.dtype)
# Fuse timestep + action information using an MLP
def action_proj_func(noisy_actions):
return self.action_in_proj(noisy_actions)
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
if not self.pi05:
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
# Apply MLP layers
def mlp_func(action_time_emb):
x = self.action_time_mlp_in(action_time_emb)
x = F.silu(x) # swish == silu
return self.action_time_mlp_out(x)
action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
adarms_cond = None
else:
# time MLP (for adaRMS)
def time_mlp_func(time_emb):
x = self.time_mlp_in(time_emb)
x = F.silu(x) # swish == silu
x = self.time_mlp_out(x)
return F.silu(x)
time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
action_time_emb = action_emb
adarms_cond = time_emb
# Add to input tokens
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
att_masks += [1] + ([0] * (self.config.action_horizon - 1))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks, adarms_cond
def forward(self, observation, actions, noise=None, time=None) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
observation, train=True
)
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
# Prepare attention masks
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
# Apply gradient checkpointing if enabled
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
return suffix_out
suffix_out = self._apply_checkpoint(
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
)
suffix_out = suffix_out[:, -self.config.action_horizon :]
suffix_out = suffix_out.to(dtype=torch.float32)
# Apply gradient checkpointing to final action projection if enabled
def action_out_proj_func(suffix_out):
return self.action_out_proj(suffix_out)
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad()
def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = observation.state.shape[0]
if noise is None:
actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
noise = self.sample_noise(actions_shape, device)
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
observation, train=False
)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# Compute image and language key value cache
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks_4d,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=True,
)
dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Euler step - use new tensor assignment instead of in-place operation
x_t = x_t + dt * v_t
time += dt
return x_t
def denoise_step(
self,
state,
prefix_pad_masks,
past_key_values,
x_t,
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
# Prepare attention masks
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.action_horizon :]
suffix_out = suffix_out.to(dtype=torch.float32)
return self.action_out_proj(suffix_out)

View File

@@ -0,0 +1,179 @@
import logging
from collections.abc import Sequence
import torch
from tests.policies.pi0_pi05.openpi_pytorch import image_tools
logger = logging.getLogger("openpi")
# Constants moved from model.py
IMAGE_KEYS = (
"base_0_rgb",
"left_wrist_0_rgb",
"right_wrist_0_rgb",
)
IMAGE_RESOLUTION = (224, 224)
def preprocess_observation_pytorch(
observation,
*,
train: bool = False,
image_keys: Sequence[str] = IMAGE_KEYS,
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
):
"""Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
This function avoids complex type annotations that can cause torch.compile issues.
"""
if not set(image_keys).issubset(observation.images):
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
batch_shape = observation.state.shape[:-1]
out_images = {}
for key in image_keys:
image = observation.images[key]
# TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
# Handle both [B, C, H, W] and [B, H, W, C] formats
is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1
if is_channels_first:
# Convert [B, C, H, W] to [B, H, W, C] for processing
image = image.permute(0, 2, 3, 1)
if image.shape[1:3] != image_resolution:
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
image = image_tools.resize_with_pad_torch(image, *image_resolution)
if train:
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
image = image / 2.0 + 0.5
# Apply PyTorch-based augmentations
if "wrist" not in key:
# Geometric augmentations for non-wrist cameras
height, width = image.shape[1:3]
# Random crop and resize
crop_height = int(height * 0.95)
crop_width = int(width * 0.95)
# Random crop
max_h = height - crop_height
max_w = width - crop_width
if max_h > 0 and max_w > 0:
# Use tensor operations instead of .item() for torch.compile compatibility
start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
# Resize back to original size
image = torch.nn.functional.interpolate(
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
size=(height, width),
mode="bilinear",
align_corners=False,
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
# Random rotation (small angles)
# Use tensor operations instead of .item() for torch.compile compatibility
angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
# Convert to radians
angle_rad = angle * torch.pi / 180.0
# Create rotation matrix
cos_a = torch.cos(angle_rad)
sin_a = torch.sin(angle_rad)
# Apply rotation using grid_sample
grid_x = torch.linspace(-1, 1, width, device=image.device)
grid_y = torch.linspace(-1, 1, height, device=image.device)
# Create meshgrid
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
# Expand to batch dimension
grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)
# Apply rotation transformation
grid_x_rot = grid_x * cos_a - grid_y * sin_a
grid_y_rot = grid_x * sin_a + grid_y * cos_a
# Stack and reshape for grid_sample
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
image = torch.nn.functional.grid_sample(
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
grid,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
# Color augmentations for all cameras
# Random brightness
# Use tensor operations instead of .item() for torch.compile compatibility
brightness_factor = (
0.7 + torch.rand(1, device=image.device) * 0.6
) # Random factor between 0.7 and 1.3
image = image * brightness_factor
# Random contrast
# Use tensor operations instead of .item() for torch.compile compatibility
contrast_factor = (
0.6 + torch.rand(1, device=image.device) * 0.8
) # Random factor between 0.6 and 1.4
mean = image.mean(dim=[1, 2, 3], keepdim=True)
image = (image - mean) * contrast_factor + mean
# Random saturation (convert to HSV, modify S, convert back)
# For simplicity, we'll just apply a random scaling to the color channels
# Use tensor operations instead of .item() for torch.compile compatibility
saturation_factor = (
0.5 + torch.rand(1, device=image.device) * 1.0
) # Random factor between 0.5 and 1.5
gray = image.mean(dim=-1, keepdim=True)
image = gray + (image - gray) * saturation_factor
# Clamp values to [0, 1]
image = torch.clamp(image, 0, 1)
# Back to [-1, 1]
image = image * 2.0 - 1.0
# Convert back to [B, C, H, W] format if it was originally channels-first
if is_channels_first:
image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
out_images[key] = image
# obtain mask
out_masks = {}
for key in out_images:
if key not in observation.image_masks:
# do not mask by default
out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)
else:
out_masks[key] = observation.image_masks[key]
# Create a simple object with the required attributes instead of using the complex Observation class
class SimpleProcessedObservation:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
return SimpleProcessedObservation(
images=out_images,
image_masks=out_masks,
state=observation.state,
tokenized_prompt=observation.tokenized_prompt,
tokenized_prompt_mask=observation.tokenized_prompt_mask,
token_ar_mask=observation.token_ar_mask,
token_loss_mask=observation.token_loss_mask,
)

View File

@@ -0,0 +1,101 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.policies.pi05 import PI05Config # noqa: E402
from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch # noqa: E402
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
assert_cache_stability,
assert_compiled_output_matches_eager,
assert_explain_has_no_graph_breaks,
benchmark_runtime,
make_compile_config,
reset_compile_state,
)
from tests.utils import require_cuda # noqa: E402
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
)
def _make_model(*, compile_model):
return PI05Pytorch(make_compile_config(PI05Config, compile_model=compile_model)).cuda().eval()
def _make_dummy_inputs(config):
device = torch.device("cuda")
common = {
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
"tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
"masks": torch.ones(1, 5, dtype=torch.bool, device=device),
}
forward_kwargs = {
**common,
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"time": torch.rand(1, device=device),
}
sample_kwargs = {
**common,
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"num_steps": config.num_inference_steps,
}
return forward_kwargs, sample_kwargs
@require_cuda
def test_pi05_torch_compile_forward_and_sample_actions():
if not hasattr(torch, "compile"):
pytest.skip("torch.compile is not available")
if not torch._dynamo.is_dynamo_supported():
pytest.skip("torch._dynamo is not supported on this platform")
torch.manual_seed(0)
eager_model = _make_model(compile_model=False)
torch.manual_seed(0)
compiled_model = _make_model(compile_model=True)
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
try:
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi05.forward")
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi05.sample_actions")
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi05.forward")
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi05.sample_actions")
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi05.forward")
benchmark_runtime(
eager_model.sample_actions,
compiled_model.sample_actions,
sample_kwargs,
"pi05.sample_actions",
)
finally:
reset_compile_state()
del eager_model
del compiled_model
torch.cuda.empty_cache()

View File

@@ -14,52 +14,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation"""
"""Compare LeRobot PI0.5 against the vendored OpenPI PyTorch reference."""
import gc
import os
from copy import deepcopy
from typing import Any
import numpy as np
import pytest
import torch
# Skip if openpi or transformers is not available
pytest.importorskip("openpi")
pytest.importorskip("transformers")
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
from lerobot.configs import PreTrainedConfig # noqa: E402
from lerobot.policies.pi05 import PI05Policy # noqa: E402
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
assert_processor_inputs_match_lerobot,
clone_batch,
deterministic_openpi_forward_preprocess,
fix_reference_state_dict,
fixed_flow_sampling,
load_openpi_reference_state_dict,
make_openpi_observation_from_raw,
openpi_model_actions_from_raw,
)
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
)
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from transformers import AutoTokenizer # noqa: E402
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
from lerobot.types import PolicyAction # noqa: E402
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 200
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
COMPILE_MODE = "default"
FORWARD_RTOL = 1e-4
FORWARD_ATOL = 1e-4
SAMPLE_RTOL = 1e-2
SAMPLE_ATOL = 5e-3
DUMMY_DATASET_STATS = {
"observation.state": {
OBS_STATE: {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
"q01": torch.zeros(DUMMY_STATE_DIM),
"q99": torch.ones(DUMMY_STATE_DIM),
},
"action": {
ACTION: {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
"q01": torch.zeros(DUMMY_ACTION_DIM),
@@ -88,6 +92,15 @@ DUMMY_DATASET_STATS = {
}
@pytest.fixture(autouse=True)
def cleanup_cuda_after_test():
yield
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
class PI05BaseOriginalConfig:
action_dim: int = DUMMY_ACTION_DIM
action_horizon: int = DUMMY_ACTION_HORIZON
@@ -96,341 +109,163 @@ class PI05BaseOriginalConfig:
precision: str = "float32"
pi05: bool = True
dtype: str = "float32"
pytorch_compile_mode: str | None = None
def instantiate_lerobot_pi05(
from_pretrained: bool = False,
) -> tuple[
PI05Policy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
if from_pretrained:
# Load the policy first
policy = PI05Policy.from_pretrained(pretrained_name_or_path="lerobot/pi05_base", strict=True)
else:
config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
policy = PI05Policy(config)
def instantiate_lerobot_pi05(*, compile_model: bool = False, gradient_checkpointing: bool = False):
config = PreTrainedConfig.from_pretrained("lerobot/pi05_base")
config.device = str(DEVICE)
config.dtype = "float32"
config.compile_model = compile_model
config.compile_mode = COMPILE_MODE
config.gradient_checkpointing = gradient_checkpointing
policy = PI05Policy.from_pretrained("lerobot/pi05_base", config=config, strict=True)
policy.to(DEVICE)
policy.config.device = DEVICE
preprocessor, postprocessor = make_pi05_pre_post_processors(
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
)
return (policy, preprocessor, postprocessor)
policy.config.device = str(DEVICE)
preprocessor, _ = make_pi05_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
return policy, preprocessor
def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None):
config = PI05BaseOriginalConfig()
policy = PI0Pytorch(config)
def instantiate_original_pi05():
policy = PI0Pytorch(PI05BaseOriginalConfig()).to(DEVICE)
if from_pretrained:
try:
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi05_base)...")
# Download the model from HuggingFace Hub
import safetensors.torch
from huggingface_hub import snapshot_download
# Download the entire repository
if model_path and os.path.exists(model_path):
cache_dir = model_path
print(f"Using cached model from: {cache_dir}")
else:
cache_dir = snapshot_download(repo_id="lerobot/pi05_base", repo_type="model")
print(f"Downloaded model to: {cache_dir}")
# Try to load safetensors format first
model_file = os.path.join(cache_dir, "model.safetensors")
if os.path.exists(model_file):
state_dict = safetensors.torch.load_file(model_file)
print(f"Loaded {len(state_dict)} parameters from safetensors")
else:
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
# Load the state dict into the model
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
if missing_keys:
print(f"Missing keys: {len(missing_keys)}")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"Unexpected keys: {len(unexpected_keys)}")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All pretrained weights loaded successfully!")
else:
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
except Exception as e:
print(f"Failed to load pretrained weights: {e}")
print(" Using randomly initialized weights...")
import traceback
traceback.print_exc()
policy.to(DEVICE)
# NOTE: `lerobot/pi05_base` 的 LeRobot loader 和 PI0 一样会在 strict load 前做 key
# 兼容转换,因此预期没有 missing_keys 或 unexpected_keys。vendored reference 则是裸
# `nn.Module`,需要在测试侧补齐 checkpoint 与模块命名之间的最小差异。
# NOTE: `lm_head.weight` 是 PaliGemma tied embedding 的保存名LeRobot 的
# from_pretrained 会把它映射到内部 `embed_tokens.weight`,而 reference 模型没有这层
# loader所以这里手动复用同一份 tensor避免把权重别名差异误判成模型差异。
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi05_base"))
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
assert missing_keys == []
assert unexpected_keys == []
return policy
def create_dummy_data():
batch_size = 2 # Reduce batch size for testing
device = DEVICE
# Use the exact same prompt for both implementations
batch_size = 2
prompt = "Pick up the red block and place it in the bin"
batch = {
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
"action": torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
return {
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
ACTION: torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
),
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
),
"observation.images.left_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
),
"observation.images.right_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
),
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
"task": [prompt for _ in range(batch_size)],
}
return batch
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
"""Extract the exact same processed inputs that LeRobot uses internally."""
# Get the tokenized language from LeRobot's internal method
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
# Get the preprocessed images from LeRobot's internal method
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
# Create dummy token_ar_mask and token_loss_mask for original implementation
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
def prepare_parity_inputs(lerobot_pi05, lerobot_preprocessor):
torch.manual_seed(0)
raw_batch = create_dummy_data()
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
openpi_observation = make_openpi_observation_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
max_token_len=DUMMY_MAX_TOKEN_LEN,
dataset_stats=DUMMY_DATASET_STATS,
pi05=True,
)
openpi_actions = openpi_model_actions_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
dataset_stats=DUMMY_DATASET_STATS,
pi05=True,
)
assert_processor_inputs_match_lerobot(
lerobot_pi05,
lerobot_batch,
openpi_observation,
compare_state=False,
)
batch_size = raw_batch[OBS_STATE].shape[0]
noise = torch.randn(
batch_size,
DUMMY_ACTION_HORIZON,
DUMMY_ACTION_DIM,
dtype=torch.float32,
device=DEVICE,
)
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
return lerobot_batch, openpi_observation, openpi_actions, noise, time
class PI05Observation:
"""Observation class that matches the original OpenPI format."""
def __init__(
self,
state,
images,
image_masks,
tokenized_prompt,
tokenized_prompt_mask,
token_ar_mask,
token_loss_mask,
):
self.state = state
self.images = images
self.image_masks = image_masks
self.tokenized_prompt = tokenized_prompt
self.tokenized_prompt_mask = tokenized_prompt_mask
self.token_ar_mask = token_ar_mask
self.token_loss_mask = token_loss_mask
def create_original_observation_with_openpi_preprocessing(batch):
"""Create observation object for OpenPI using OpenPI's own preprocessing with pi05 state tokenizer."""
batch_size = batch["observation.state"].shape[0]
device = batch["observation.state"].device
# Create tokenizer for OpenPI (same as LeRobot uses)
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
# Get task description (pi05 processor handles all text formatting)
tasks = batch.get("task", ["Pick up the object"] * batch_size)
if isinstance(tasks, str):
tasks = [tasks] * batch_size
elif len(tasks) == 1:
tasks = tasks * batch_size
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
state = batch["observation.state"]
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
from lerobot.policies.pi05.modeling_pi05 import pad_vector
state = pad_vector(state, DUMMY_STATE_DIM)
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Create pi05-formatted prompts that include state information
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
# Tokenize with max_length padding to match OpenPI's expected format
tokenized = tokenizer(
full_prompts,
padding="max_length",
padding_side="right",
truncation=True,
max_length=DUMMY_MAX_TOKEN_LEN,
return_tensors="pt",
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(
compile_model=compile_model,
gradient_checkpointing=gradient_checkpointing,
)
original_pi05 = instantiate_original_pi05()
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
lerobot_pi05,
lerobot_preprocessor,
)
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
if gradient_checkpointing:
lerobot_pi05.train()
else:
lerobot_pi05.eval()
original_pi05.eval()
# Create dummy token_ar_mask and token_loss_mask for OpenPI
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
with fixed_flow_sampling(lerobot_pi05.model, noise=noise, time=time):
lerobot_loss, _ = lerobot_pi05(lerobot_batch, reduction="none")
with deterministic_openpi_forward_preprocess(original_pi05):
openpi_losses = original_pi05(openpi_observation, openpi_actions, noise=noise, time=time)
openpi_loss = openpi_losses.mean(dim=(1, 2))
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
image_dict = {
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
}
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
# Create image masks (all ones for real images)
image_masks_dict = {}
for key in image_dict:
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
# Create raw observation object (before preprocessing)
raw_observation = PI05Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(compile_model=compile_model)
original_pi05 = instantiate_original_pi05()
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
lerobot_pi05,
lerobot_preprocessor,
)
# Now use OpenPI's preprocessing
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
return processed_obs
def create_original_observation_from_lerobot(lerobot_pi0, batch):
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
_batch_size = batch["observation.state"].shape[0]
_device = batch["observation.state"].device
# Extract the exact same processed inputs that LeRobot uses
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
extract_lerobot_processed_inputs(lerobot_pi0, batch)
)
# Convert images list to dict with original OpenPI keys
image_dict = {
"base_0_rgb": images[0],
"left_wrist_0_rgb": images[1],
"right_wrist_0_rgb": images[2],
}
# Convert image masks list to dict with original OpenPI keys
image_masks_dict = {
"base_0_rgb": img_masks[0],
"left_wrist_0_rgb": img_masks[1],
"right_wrist_0_rgb": img_masks[2],
}
return PI05Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
def test_pi05_original_vs_lerobot():
"""Test PI05 original implementation vs LeRobot implementation."""
print("Initializing models...")
lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05(
from_pretrained=True
) # Load pretrained LeRobot model
original_pi0 = instantiate_original_pi05(
from_pretrained=True
) # Load pretrained OpenPI model from HuggingFace Hub
print("Creating dummy data...")
batch = create_dummy_data()
batch_lerobot = deepcopy(batch)
# Test each model with its own preprocessing (more realistic end-to-end test)
print("\nTest each model with its own preprocessing")
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
print(f"Task prompt: '{batch['task'][0]}'")
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
print("Testing OpenPI with own preprocessing...")
original_pi0.eval()
torch.manual_seed(42) # Set seed for reproducibility
batch_size = batch["observation.state"].shape[0]
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
openpi_actions = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
)
openpi_actions_unit = openpi_actions[:, 0, :]
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
print("Testing LeRobot with own preprocessing...")
lerobot_pi05.eval()
torch.manual_seed(42) # Set the same seed
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
original_pi05.eval()
with torch.no_grad():
lerobot_actions_own = lerobot_pi05.predict_action_chunk(
batch_lerobot_processed
) # batch_size, n_action_steps, action_dim
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
lerobot_actions = lerobot_pi05.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
openpi_actions = original_pi05.sample_actions(
device=DEVICE,
observation=openpi_observation,
noise=noise,
num_steps=10,
)
print("\nComparing end-to-end implementations:")
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
def test_pi05_forward_matches_openpi():
assert_forward_matches()
def test_pi05_sample_actions_match_openpi():
assert_sample_actions_match_openpi()
def test_pi05_gradient_checkpointing_forward_matches_openpi():
assert_forward_matches(gradient_checkpointing=True)
def test_pi05_compile_forward_matches_openpi():
assert_forward_matches(compile_model=True)
def test_pi05_compile_sample_actions_match_openpi():
assert_sample_actions_match_openpi(compile_model=True)
def test_pi05_compile_gradient_checkpointing_forward_matches_openpi():
assert_forward_matches(compile_model=True, gradient_checkpointing=True)

View File

@@ -0,0 +1,99 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.policies.pi0 import PI0Config # noqa: E402
from lerobot.policies.pi0.modeling_pi0 import PI0Pytorch # noqa: E402
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
assert_cache_stability,
assert_compiled_output_matches_eager,
assert_explain_has_no_graph_breaks,
benchmark_runtime,
make_compile_config,
reset_compile_state,
)
from tests.utils import require_cuda # noqa: E402
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
)
def _make_model(*, compile_model):
return PI0Pytorch(make_compile_config(PI0Config, compile_model=compile_model)).cuda().eval()
def _make_dummy_inputs(config):
device = torch.device("cuda")
common = {
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
"lang_tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
"lang_masks": torch.ones(1, 5, dtype=torch.bool, device=device),
"state": torch.randn(1, config.max_state_dim, device=device),
}
forward_kwargs = {
**common,
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"time": torch.rand(1, device=device),
}
sample_kwargs = {
**common,
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"num_steps": config.num_inference_steps,
}
return forward_kwargs, sample_kwargs
@require_cuda
def test_pi0_torch_compile_forward_and_sample_actions():
if not hasattr(torch, "compile"):
pytest.skip("torch.compile is not available")
if not torch._dynamo.is_dynamo_supported():
pytest.skip("torch._dynamo is not supported on this platform")
torch.manual_seed(0)
eager_model = _make_model(compile_model=False)
torch.manual_seed(0)
compiled_model = _make_model(compile_model=True)
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
try:
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi0.forward")
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi0.sample_actions")
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi0.forward")
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions")
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi0.forward")
benchmark_runtime(
eager_model.sample_actions, compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions"
)
finally:
reset_compile_state()
del eager_model
del compiled_model
torch.cuda.empty_cache()

Some files were not shown because too many files have changed in this diff Show More