mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
Add extensive language support (#3467)
* Add extensive language support * Address review: split persistent/event schemas, drop event timestamps - recipe.py: derive _VALID_ROLES/_VALID_STREAMS from MessageRole/MessageStream Literals - dataset_metadata.py: keep CODEBASE_VERSION at v3.0 - language.py: remove RESERVED_STYLES; split arrow/feature schemas into persistent (with timestamp) and event (without timestamp); add docstrings - language_render.py: events use frame-row timestamp implicitly; no per-event timestamp filtering or sorting - converters.py: drop unused subtask_key passthrough - add docstrings to new public APIs (recipe, render_messages_processor, collate) - update tests for split schemas; revert uv.lock Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Add docstrings to all new helpers; revert uv.lock Covers private helpers in recipe.py, language.py, language_render.py, and render_messages_processor.py. Also reverts uv.lock to main (it was re-generated by `uv run` during local checks). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * feat(language): add motion (persistent) and trace (event-only) styles Promote the previously-reserved motion/trace styles to first-class core styles. motion routes to language_persistent (it tracks robot state over time); trace routes to language_events (single-moment annotations). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * feat(language): per-camera tagging on view-dependent styles Adds a nullable `camera` field to the language row struct (both persistent and event variants) so view-dependent styles like `vqa` can carry which `observation.images.*` view they were grounded against. Without this, multi-camera datasets ended up with multiple `(vqa, role)` rows at the same timestamp that the resolver could not disambiguate. - `language.py`: add `camera` to PERSISTENT_ROW_FIELDS / EVENT_ROW_FIELDS, to both Arrow struct types and the HF datasets feature mappings; introduce VIEW_DEPENDENT_STYLES = {vqa, motion, trace} plus `is_view_dependent_style` and `validate_camera_field` helpers (camera required iff style is view-dependent). - `language_render.py`: thread an optional `camera=` kwarg through every resolver (`active_at`, `emitted_at`, `nth_prev`, `nth_next`) and through `_matching_rows` / `_select_*`, so recipes can disambiguate per-camera VQA with `emitted_at(t, style=vqa, role=assistant, camera=...)`. Without a `camera` filter, multi-row matches keep raising the existing ambiguity error — which is the desired behaviour on multi-camera data. - `recipes/pi05_hirobot.yaml`: replace the single `ask_vqa` branch with `ask_vqa_top` and `ask_vqa_wrist` per-camera sub-recipes (each carrying the matching image block), keeping the original 0.20 budget and documenting the customization point for datasets with different cameras. - Tests: schema test asserts the new field order; new tests cover `is_view_dependent_style`, `validate_camera_field` (both required and forbidden directions), per-camera `emitted_at` filtering, and the ambiguity error when two cameras emit `(vqa, assistant)` at the same timestamp without a `camera=` filter. RenderMessagesStep + dataset passthrough fixtures updated to include the new field. - `docs/source/language_and_recipes.mdx`: document the `camera` field, the per-camera resolver pattern, and the canonical recipe convention. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(language): drop motion from VIEW_DEPENDENT_STYLES Motion primitives are described in robot-frame (joint / Cartesian) terms, not pixel space, so they are camera-agnostic. Only `vqa` (event) and `trace` (event, pixel-trajectory) are view-dependent. The `camera` field stays on PERSISTENT_ROW_FIELDS for schema symmetry — the validator, resolver, and HF feature mapping behave identically across the two columns regardless of which styles populate `camera` today — but persistent rows now always have `camera=None` in practice. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * feat(language): task_aug style + automatic ${task} rephrasing rotation Adds task-prompt diversity (Xiao 2022 / CAST) without touching ``meta/tasks.parquet`` or forcing recipes to opt in. The plan reserved ``task_aug`` as a future style; this lands it now. - ``language.py``: add ``task_aug`` to ``CORE_STYLES`` and ``PERSISTENT_STYLES``. ``column_for_style("task_aug")`` returns ``language_persistent`` so PR 2 writers route it correctly. - ``language_render.py``: ``_resolve_task`` now consults the persistent slice for rows of ``style="task_aug", role="user"``. When any exist it picks one deterministically by ``sample_idx`` (blake2b-keyed, not Python's randomized hash) so an epoch sees every rephrasing of every episode while the same sample still resolves identically across reruns. Falls back to the canonical ``meta/tasks.parquet`` task when no rephrasings are present, so existing datasets and unannotated runs keep their behaviour. Explicit ``task=`` overrides still win. - Tests: rephrasing coverage across samples, determinism on repeat ``sample_idx``, fallback when persistent has no ``task_aug`` rows, and explicit override priority. Recipes get this for free: any ``${task}`` placeholder rotates through the available rephrasings. Recipes that want the literal canonical task can override the binding. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * feat(language): tool catalog in meta/info.json + LeRobotDatasetMetadata.tools Stores OpenAI-style function schemas at ``meta/info.json["tools"]`` so datasets can declare which tools are available (today: just ``say``; tomorrow: per-dataset extensions). The ``DEFAULT_TOOLS`` constant fills in for unannotated datasets so chat-template consumers don't have to special-case anything. Three pieces: - ``language.py``: ``SAY_TOOL_SCHEMA`` and ``DEFAULT_TOOLS`` constants. Single source of truth — PR 2's writer and PR 3's runtime tool registry will both import from here instead of duplicating the dict. - ``dataset_metadata.py``: ``LeRobotDatasetMetadata.tools`` property reads ``info.json["tools"]`` and falls back to ``DEFAULT_TOOLS``. Returns deep-copied dicts so callers can mutate the result safely. - ``docs/source/tools.mdx``: spec page covering the catalog, per-row invocations, and the three-step "how to add a new tool" workflow (declare schema, implement, register). Linked from the docs toctree under the Datasets section. This lays the groundwork for PR 2's pipeline writing the catalog out during annotation, and PR 3's ``src/lerobot/tools/`` package shipping runnable implementations (one file per tool — first up: ``say.py`` wrapping Kyutai's pocket-tts). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Apply ruff and prettier formatting after merge Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * refactor(language): unify resolver dispatch and prune redundant test scaffolding * Drop the unused `events` kwarg from `active_at`/`nth_prev`/`nth_next`; only `emitted_at` actually consults events. The dispatcher in `_resolve_spec` now passes events conditionally. * Replace the dual `_persistent_sort_key`/`_event_sort_key` pair with a single `_row_sort_key` and drop the `sort_key` parameter from `_select_one`. Event rows lack `timestamp` (it is implicit in the frame) and now default to `0.0` for sort purposes — the `(style, role)` tiebreaker is unchanged. * Inline `_select_latest` into `active_at` (its only caller). * Collapse `emitted_at`'s dual-branch into one `_select_one` call. * Tighten `_validate_persistent_resolver` to a single `column_for_style(style) != LANGUAGE_PERSISTENT` check. * Parameterize `test_per_camera_blend_renders_both_views` over the two cameras and factor the sub-recipe builder into `_vqa_subrecipe` so the test no longer hand-rolls two near-identical recipe blocks. Net -98 LOC; behavior, public resolver names, and test expectations unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(language): always raise on ambiguous resolver matches `_select_one` previously skipped its ambiguity check whenever any of `role`/`tool_name`/`camera` was set, on the assumption that the caller had already pinned down a unique row. That left a real ambiguity hole for VQA: with two cameras emitting `(vqa, assistant)` at the same frame, `emitted_at(..., role="assistant")` silently picked the first sorted row instead of telling the recipe to add `camera=...`. The existing `test_emitted_at_raises_on_ambiguous_per_camera_vqa` test already encoded the desired behavior. Tighten the check: any time `len(rows) > 1` we now raise with the selectors echoed back, so users see exactly which fields they passed and that more is needed to disambiguate. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * chore: fix CI — collapse short ValueError to one line, refresh uv.lock * `ruff format` on CI (newer version) wants the short `camera=None` ValueError on a single line. * `uv.lock` was stale relative to `pyproject.toml`'s `datasets>=4.7.0` pin (and picked up upstream `s390x` marker fixes for cuda packages). CI runs `uv sync --locked` which rejected the divergence. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(language): keep base install green — drop processor re-export, gate dataset-extra tests `lerobot.processor` re-exported `RenderMessagesStep` at the package level, so importing anything from `lerobot.processor` pulled in `lerobot.datasets.language` → `lerobot.datasets/__init__.py` → `require_package("datasets")`, which fails in the Tier 1 base install that intentionally omits the `[dataset]` extra. The chain bricked collection for unrelated suites (`tests/policies/pi0_pi05/...`, `tests/envs/...`, etc.). * Stop re-exporting `RenderMessagesStep` from `lerobot.processor`. The only consumer (the test) already imports from the submodule. Document the deliberate omission in the module docstring. * Add `pytest.importorskip("datasets", ...)` (and `pandas` where needed) at the top of the four PR-added tests that exercise the language stack: - tests/datasets/test_language.py - tests/datasets/test_language_render.py - tests/processor/test_render_messages_processor.py - tests/utils/test_collate.py Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(language): address review — tools accessor, motion docs, conditional collate * **`meta.tools` actually reads `info.json["tools"]`.** `DatasetInfo` had no `tools` field, so `from_dict` silently dropped the key (it warned about unknown fields then discarded them) and the property always returned `DEFAULT_TOOLS`. Added `tools: list[dict] | None` to the dataclass; `to_dict()` drops it when unset so existing datasets keep a clean `info.json`. Fixed the accessor to read `self.info.tools` (the previous `.get(...)` would have raised AttributeError on the dataclass anyway). Added regression tests: fallback when absent, round-trip from disk, and round-trip through `DatasetInfo.from_dict` / `to_dict`. * **`motion` is not view-dependent — fix the docs.** The mdx claimed rows of style `motion` must carry `camera`, but `VIEW_DEPENDENT_STYLES = {"vqa", "trace"}` and the validator agrees: motion primitives are joint/Cartesian-frame, not pixel-space. Updated both call-out paragraphs in `language_and_recipes.mdx`. * **Conditional `collate_fn` swap.** Added `meta.has_language_columns` and gate the `lerobot_collate_fn` swap in `lerobot_train.py` on it, so non-language datasets keep PyTorch's `default_collate`. Also added a pass-through test in `test_collate.py` that asserts on a plain tensor batch the custom collate matches `default_collate` key-for-key, plus a test for the `None`-sample drop path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * review: dedupe regex, centralize column names, harden collate, more tests * **#2 — dedupe `_PLACEHOLDER_RE`.** The same regex was compiled in `recipe.py` and `language_render.py`. Promote to module-level `PLACEHOLDER_RE` in `recipe.py` (its primary owner — declares template syntax) and import from `language_render.py`. * **#3 — centralize language column names.** `io_utils.py` had hardcoded `{"language_persistent", "language_events"}` literals at two sites. Replace with `LANGUAGE_COLUMNS` import so a future column rename can't silently desync. * **#4 — defensive collate preserved-keys.** `lerobot_collate_fn` silently filtered language fields from samples that didn't have them, which would hand downstream consumers a preserved list shorter than the tensor batch. Now: if any sample carries a key, every sample in the batch must carry it; otherwise raise a `ValueError` so the upstream rendering bug surfaces at the boundary. * **#5 — `_scalar` rejects non-singleton lists.** Previously a zero- or multi-element list fell through and triggered confusing `float([])` errors downstream. Now raises `ValueError` with the actual length. * **#6 — refactor `_extract_complementary_data`.** Replace 11 lines of `key = {... if ... else {}}` plus an 11-line splat dict with a single `_COMPLEMENTARY_KEYS` tuple iterated once. * **#7 — document `EXTENDED_STYLES`.** Was an empty `set()` with no comment. Add a docstring explaining it's an intentional extension point: downstream modules append project-local styles before `column_for_style` is called. * **#9 — `tools.mdx` notes the runtime layer is future work.** The page referenced `src/lerobot/tools/`, `registry.py`, and `get_tools(meta)` — none exist in this PR. Added a callout at the start of "How to add your own tool" plus a note on the implementations paragraph. * **#10 — tests for YAML round-trip, malformed rows, blend validation.** `test_recipe.py` grew from 1 case to 12 covering: blend-or-messages exclusivity, target-turn requirement, blend emptiness, weight presence/positivity, nested-blend rejection, `from_dict` with nested blends, `from_yaml` / `load_recipe` agreement, top-level non-mapping rejection. Added a malformed-row test for `_normalize_rows` that asserts non-dict entries raise `TypeError`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * review: emitted_at uses 0.1s tolerance; MessageTurn requires stream at construction * **Float tolerance in `emitted_at` for persistent styles.** The ``_timestamp(row) == t`` exact-equality check silently missed any caller that derived ``t`` arithmetically (e.g. ``frame_idx / fps``) even though the parquet timestamp would only differ by ULPs. Added ``EMITTED_AT_TOLERANCE_S = 0.1`` and check ``abs(...) <= tolerance`` instead, with a docstring explaining why exact equality wasn't enough and why 0.1 s is safe at typical 30–100 Hz control rates. Test asserts the new behavior at half-window (matches) and double-window (no match) using the constant so it stays in sync. * **`MessageTurn.stream` is required at construction.** It was typed ``MessageStream | None = None`` so YAML could omit ``stream:`` and pass the dataclass invariant — but ``_validate_rendered`` rejected ``None`` streams later, surfacing the error at the first sample instead of at recipe load. Now ``__post_init__`` raises ``ValueError`` if ``stream`` is ``None``, with the list of valid streams in the message. The redundant late-stage check in ``_validate_rendered`` is replaced with a one-line comment that cites the upstream invariant. Test pins the new construction-time rejection. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * docs(tools): drop follow-up-PR references Reword the two callouts in `tools.mdx` to describe the runtime layer in present tense ("not part of the catalog layer shipped today", "those modules don't yet exist in the tree") instead of pointing at a specific follow-up PR. Keeps the doc honest about what works now without coupling it to a particular release order. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * review: address CarolinePascal feedback - language timestamps: float64 -> float32 to match LeRobotDataset frame timestamps (Arrow struct + HF feature) - dataset_metadata: hoist `.language` imports to module top — language.py has no lerobot imports, so there is no circular-import risk - dataset_metadata: add a `meta.tools` setter that persists the catalog to info.json and reloads `meta.info` - feature_utils: validate the `language` dtype instead of returning "" — warn (non-fatal) when a non-empty value is written at record time - centralize the scalar-unwrap helper as `lerobot.utils.utils.unwrap_scalar`, shared by render_messages_processor and language_render - docs: move `## Layer 2 — recipe anatomy` ahead of the resolver sections, which describe recipe bindings rather than dataset layout - language_render: note in EMITTED_AT_TOLERANCE_S that persistent rows change on a human-action timescale, not the camera frame rate Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -24,6 +24,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||
from .dataset import DatasetRecordConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
from .types import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
@@ -49,9 +50,12 @@ __all__ = [
|
||||
"DatasetRecordConfig",
|
||||
"DatasetConfig",
|
||||
"EvalConfig",
|
||||
"MessageTurn",
|
||||
"PeftConfig",
|
||||
"PreTrainedConfig",
|
||||
"TrainingRecipe",
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
"VideoEncoderConfig",
|
||||
# Defaults
|
||||
"camera_encoder_defaults",
|
||||
|
||||
206
src/lerobot/configs/recipe.py
Normal file
206
src/lerobot/configs/recipe.py
Normal file
@@ -0,0 +1,206 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
MessageRole = Literal["user", "assistant", "system", "tool"]
|
||||
MessageStream = Literal["high_level", "low_level"]
|
||||
|
||||
DEFAULT_BINDINGS = {
|
||||
"subtask": "active_at(t, style=subtask)",
|
||||
"memory": "active_at(t, style=memory)",
|
||||
"plan": "active_at(t, style=plan)",
|
||||
"speech": "emitted_at(t, role=assistant, tool_name=say)",
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant)",
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
}
|
||||
|
||||
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
"""``${name}`` placeholder pattern used by both recipe binding-reference
|
||||
discovery (here) and rendered-message substitution (in ``language_render``)."""
|
||||
|
||||
_VALID_ROLES = frozenset(get_args(MessageRole))
|
||||
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageTurn:
|
||||
"""A single chat-style turn in a recipe template.
|
||||
|
||||
``content`` may be a plain string, a list of HF-style multimodal blocks, or
|
||||
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
|
||||
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
|
||||
training target, and ``if_present`` skips the turn when the named binding
|
||||
resolves to ``None``.
|
||||
"""
|
||||
|
||||
role: MessageRole
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
stream: MessageStream | None = None
|
||||
target: bool = False
|
||||
if_present: str | None = None
|
||||
tool_calls_from: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate role, stream, and content after dataclass construction."""
|
||||
if self.role not in _VALID_ROLES:
|
||||
raise ValueError(f"Unsupported message role: {self.role!r}")
|
||||
# ``stream`` is typed Optional only so the dataclass can keep its
|
||||
# field ordering, but recipes must always tag every turn with a
|
||||
# stream — the renderer's ``_validate_rendered`` would reject
|
||||
# ``None`` later on. Fail at construction so the bad recipe is
|
||||
# caught at YAML load time rather than at the first sample.
|
||||
if self.stream is None:
|
||||
raise ValueError(
|
||||
f"MessageTurn(role={self.role!r}) is missing a stream — "
|
||||
f"every turn must declare one of {sorted(_VALID_STREAMS)}."
|
||||
)
|
||||
if self.stream not in _VALID_STREAMS:
|
||||
raise ValueError(f"Unsupported message stream: {self.stream!r}")
|
||||
if self.content is None and self.tool_calls_from is None:
|
||||
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
|
||||
if self.content is not None and not isinstance(self.content, (str, list)):
|
||||
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
|
||||
if isinstance(self.content, list):
|
||||
for block in self.content:
|
||||
if not isinstance(block, dict) or "type" not in block:
|
||||
raise ValueError(
|
||||
"Multimodal content blocks must be HF-style dictionaries with a type key."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
|
||||
"""Construct a :class:`MessageTurn` from a plain dictionary."""
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingRecipe:
|
||||
"""A recipe describing how to render training samples from language rows.
|
||||
|
||||
A recipe is either a *message recipe* (``messages`` plus optional
|
||||
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
|
||||
sub-recipes). ``weight`` is only meaningful inside a blend.
|
||||
"""
|
||||
|
||||
messages: list[MessageTurn] | None = None
|
||||
bindings: dict[str, str] | None = None
|
||||
blend: dict[str, TrainingRecipe] | None = None
|
||||
weight: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
|
||||
if self.messages is not None and self.blend is not None:
|
||||
raise ValueError("TrainingRecipe must set only one of messages or blend.")
|
||||
if self.messages is None and self.blend is None:
|
||||
raise ValueError("TrainingRecipe must set one of messages or blend.")
|
||||
|
||||
if self.messages is not None:
|
||||
self._validate_message_recipe()
|
||||
if self.blend is not None:
|
||||
self._validate_blend_recipe()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
|
||||
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
|
||||
data = dict(data)
|
||||
if data.get("messages") is not None:
|
||||
data["messages"] = [
|
||||
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
|
||||
for turn in data["messages"]
|
||||
]
|
||||
if data.get("blend") is not None:
|
||||
data["blend"] = {
|
||||
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
|
||||
for name, recipe in data["blend"].items()
|
||||
}
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
"""Ensure every templated binding is known and at least one turn is a target."""
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
for turn in self.messages:
|
||||
missing = self._referenced_bindings(turn) - known_bindings
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
if not any(turn.target for turn in self.messages):
|
||||
raise ValueError("Message recipes must contain at least one target turn.")
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
"""Ensure each blend component is a non-empty, weighted message recipe."""
|
||||
assert self.blend is not None
|
||||
if not self.blend:
|
||||
raise ValueError("Blend recipes must contain at least one component.")
|
||||
|
||||
for name, recipe in self.blend.items():
|
||||
if recipe.blend is not None:
|
||||
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
|
||||
if recipe.messages is None:
|
||||
raise ValueError(f"Blend component {name!r} must define messages.")
|
||||
if recipe.weight is None:
|
||||
raise ValueError(f"Blend component {name!r} must define weight.")
|
||||
if recipe.weight <= 0:
|
||||
raise ValueError(f"Blend component {name!r} must have a positive weight.")
|
||||
|
||||
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
|
||||
"""Return the binding names that ``turn`` references via placeholders or attributes."""
|
||||
names: set[str] = set()
|
||||
if turn.if_present is not None:
|
||||
names.add(turn.if_present)
|
||||
if turn.tool_calls_from is not None:
|
||||
names.add(turn.tool_calls_from)
|
||||
names.update(_placeholders_in_content(turn.content))
|
||||
return names
|
||||
|
||||
|
||||
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
|
||||
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
|
||||
if content is None:
|
||||
return set()
|
||||
if isinstance(content, str):
|
||||
return set(PLACEHOLDER_RE.findall(content))
|
||||
|
||||
names: set[str] = set()
|
||||
for block in content:
|
||||
for value in block.values():
|
||||
if isinstance(value, str):
|
||||
names.update(PLACEHOLDER_RE.findall(value))
|
||||
return names
|
||||
|
||||
|
||||
def load_recipe(path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
return TrainingRecipe.from_yaml(path)
|
||||
@@ -37,6 +37,14 @@ from .dataset_tools import (
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
STYLE_REGISTRY,
|
||||
column_for_style,
|
||||
)
|
||||
from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
@@ -54,10 +62,15 @@ __all__ = [
|
||||
"CODEBASE_VERSION",
|
||||
"DEFAULT_EPISODES_PATH",
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"EpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
"LeRobotDataset",
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"PERSISTENT_STYLES",
|
||||
"STYLE_REGISTRY",
|
||||
"StreamingLeRobotDataset",
|
||||
"VideoEncodingManager",
|
||||
"check_video_encoder_parameters_pyav",
|
||||
@@ -69,6 +82,7 @@ __all__ = [
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
"get_feature_stats",
|
||||
"load_episodes",
|
||||
|
||||
@@ -512,7 +512,7 @@ def compute_episode_stats(
|
||||
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
if features[key]["dtype"] in {"string", "language"}:
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
|
||||
@@ -36,12 +36,12 @@ from .io_utils import (
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from .language import DEFAULT_TOOLS, LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
check_version_compatibility,
|
||||
@@ -177,7 +177,6 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -343,6 +342,49 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def has_language_columns(self) -> bool:
|
||||
"""Return ``True`` if the dataset declares any language column.
|
||||
|
||||
Used to gate language-aware code paths (collate, render step) so
|
||||
unannotated datasets keep PyTorch's default collate behavior.
|
||||
"""
|
||||
return any(col in self.features for col in LANGUAGE_COLUMNS)
|
||||
|
||||
@property
|
||||
def tools(self) -> list[dict]:
|
||||
"""OpenAI-style tool schemas declared by this dataset.
|
||||
|
||||
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
|
||||
can mutate the result safely. Falls back to
|
||||
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
|
||||
``say`` schema) when the dataset doesn't declare any — that way
|
||||
unannotated datasets and chat-template consumers
|
||||
(``apply_chat_template(messages, tools=meta.tools)``) keep
|
||||
working out of the box.
|
||||
|
||||
Implementations live under :mod:`lerobot.tools` (one file per
|
||||
tool); see ``docs/source/tools.mdx`` for the authoring guide.
|
||||
"""
|
||||
declared = self.info.tools
|
||||
if declared:
|
||||
return [dict(t) for t in declared]
|
||||
return [dict(t) for t in DEFAULT_TOOLS]
|
||||
|
||||
@tools.setter
|
||||
def tools(self, value: list[dict] | None) -> None:
|
||||
"""Persist a tool catalog to ``meta/info.json`` and reload metadata.
|
||||
|
||||
Writes ``value`` into the on-disk ``info.json`` (or clears the
|
||||
``tools`` key when ``value`` is ``None`` or empty), then reloads
|
||||
``self.info`` so the in-memory metadata matches what's on disk.
|
||||
Saves callers from hand-editing ``info.json`` and re-instantiating
|
||||
the metadata object.
|
||||
"""
|
||||
self.info.tools = [dict(t) for t in value] if value else None
|
||||
write_info(self.info, self.root)
|
||||
self.info = load_info(self.root)
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
@@ -671,7 +713,6 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
|
||||
@@ -295,9 +295,4 @@ class DatasetReader:
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self._meta.tasks.iloc[task_idx].name
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pprint import pformat
|
||||
|
||||
import datasets
|
||||
@@ -23,6 +24,12 @@ from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
from .language import (
|
||||
LANGUAGE_PERSISTENT,
|
||||
is_language_column,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -47,7 +54,13 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
if is_language_column(key):
|
||||
hf_features[key] = (
|
||||
language_persistent_column_feature()
|
||||
if key == LANGUAGE_PERSISTENT
|
||||
else language_events_column_feature()
|
||||
)
|
||||
elif ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
@@ -278,6 +291,8 @@ def validate_feature_dtype_and_shape(
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
elif expected_dtype == "language":
|
||||
return validate_feature_language(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
@@ -357,6 +372,30 @@ def validate_feature_string(name: str, value: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def validate_feature_language(name: str, value) -> str:
|
||||
"""Validate a feature that is expected to hold language annotations.
|
||||
|
||||
Language columns (``language_persistent`` / ``language_events``) are
|
||||
populated after recording by the annotation pipeline, not at record time.
|
||||
Any value supplied here is dropped before the frame is written, so a
|
||||
non-empty value almost certainly signals a mistake. We warn rather than
|
||||
fail to keep recording resilient.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
value: The value to validate.
|
||||
|
||||
Returns:
|
||||
str: Always an empty string — language values are non-fatal.
|
||||
"""
|
||||
if value is not None:
|
||||
logging.warning(
|
||||
f"The feature '{name}' is a 'language' column populated by the annotation pipeline, "
|
||||
f"not at record time. The provided value will be dropped."
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
|
||||
"""Validate the episode buffer before it's written to disk.
|
||||
|
||||
|
||||
@@ -31,10 +31,10 @@ from torchvision import transforms
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
|
||||
|
||||
from .language import LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
@@ -186,14 +186,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
@@ -265,11 +257,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
if key in LANGUAGE_COLUMNS:
|
||||
continue
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
elif first_item is None or isinstance(first_item, dict):
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
@@ -304,8 +298,9 @@ def item_to_torch(item: dict) -> dict:
|
||||
Returns:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
skip_keys = {"task", *LANGUAGE_COLUMNS}
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
242
src/lerobot/datasets/language.py
Normal file
242
src/lerobot/datasets/language.py
Normal file
@@ -0,0 +1,242 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import datasets
|
||||
import pyarrow as pa
|
||||
|
||||
LANGUAGE_PERSISTENT = "language_persistent"
|
||||
LANGUAGE_EVENTS = "language_events"
|
||||
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
|
||||
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
|
||||
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
|
||||
|
||||
CORE_STYLES = {
|
||||
"subtask",
|
||||
"plan",
|
||||
"memory",
|
||||
"motion",
|
||||
"interjection",
|
||||
"vqa",
|
||||
"trace",
|
||||
"task_aug",
|
||||
}
|
||||
# Project-local styles can be registered at import time by appending to
|
||||
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
|
||||
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
|
||||
# validation. Empty by default — populate from a downstream module that
|
||||
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
|
||||
# the new style's column.
|
||||
EXTENDED_STYLES: set[str] = set()
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||
|
||||
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
||||
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
|
||||
# feature key. Rows of every other style MUST have ``camera=None``. ``motion``
|
||||
# is intentionally NOT in this set: motion primitives are described in
|
||||
# robot-frame (joint / Cartesian) terms, not pixel space, so they are
|
||||
# camera-agnostic. ``trace`` is the pixel-trajectory event style and IS
|
||||
# view-dependent. The ``camera`` field nevertheless lives on
|
||||
# ``PERSISTENT_ROW_FIELDS`` too so the schema, validator, and resolver
|
||||
# behave symmetrically across the two columns; persistent rows simply
|
||||
# always have ``camera=None`` in practice today.
|
||||
VIEW_DEPENDENT_STYLES = {"vqa", "trace"}
|
||||
|
||||
LanguageColumn = Literal["language_persistent", "language_events"]
|
||||
|
||||
|
||||
def _json_arrow_type() -> pa.DataType:
|
||||
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
|
||||
return pa.json_() if hasattr(pa, "json_") else pa.string()
|
||||
|
||||
|
||||
def _json_feature() -> object:
|
||||
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
|
||||
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
|
||||
|
||||
|
||||
def language_persistent_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single persistent language row.
|
||||
|
||||
Persistent rows carry their own ``timestamp`` because they represent a state
|
||||
that became active at a specific moment and remains active until superseded.
|
||||
``timestamp`` is ``float32`` to match the timestamp dtype LeRobotDataset
|
||||
uses for frame data.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("timestamp", pa.float32(), nullable=False),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_event_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single event language row.
|
||||
|
||||
Event rows have no ``timestamp`` field: each event is stored on the dataset
|
||||
row whose frame timestamp is the event's firing time.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_persistent_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_persistent`` column."""
|
||||
return pa.list_(language_persistent_row_arrow_type())
|
||||
|
||||
|
||||
def language_events_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_events`` column."""
|
||||
return pa.list_(language_event_row_arrow_type())
|
||||
|
||||
|
||||
def language_persistent_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"timestamp": datasets.Value("float32"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_event_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for an event language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_persistent_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
|
||||
return datasets.List(language_persistent_row_feature())
|
||||
|
||||
|
||||
def language_events_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
|
||||
return datasets.List(language_event_row_feature())
|
||||
|
||||
|
||||
def language_feature_info() -> dict[str, dict]:
|
||||
"""Return the ``info["features"]`` entries for both language columns."""
|
||||
return {
|
||||
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
|
||||
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
|
||||
def is_language_column(key: str) -> bool:
|
||||
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
|
||||
return key in LANGUAGE_COLUMNS
|
||||
|
||||
|
||||
def is_view_dependent_style(style: str | None) -> bool:
|
||||
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
|
||||
return style in VIEW_DEPENDENT_STYLES
|
||||
|
||||
|
||||
def validate_camera_field(style: str | None, camera: str | None) -> None:
|
||||
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
|
||||
|
||||
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
|
||||
a non-view-dependent style carries one. Pipeline writers and the validator
|
||||
should call this on every emitted row.
|
||||
"""
|
||||
if is_view_dependent_style(style):
|
||||
if not camera:
|
||||
raise ValueError(
|
||||
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
|
||||
f"field referencing an 'observation.images.*' feature key."
|
||||
)
|
||||
elif camera is not None:
|
||||
raise ValueError(f"Rows of style {style!r} must have camera=None; got camera={camera!r}.")
|
||||
|
||||
|
||||
# --- Tool registry --------------------------------------------------------
|
||||
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
|
||||
# of OpenAI-style function schemas. The runtime / training stack reads them
|
||||
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
|
||||
# fallback when the dataset doesn't declare any). Implementations live
|
||||
# under :mod:`lerobot.tools` (one file per tool); see
|
||||
# ``docs/source/tools.mdx`` for the authoring guide.
|
||||
|
||||
SAY_TOOL_SCHEMA: dict = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak.",
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""Canonical schema for the ``say`` tool emitted by the steerable
|
||||
annotation pipeline (PR 2 Module 2). Single source of truth — PR 2's
|
||||
writer, PR 3's runtime tool registry, and the dataset visualizer all
|
||||
import this constant rather than duplicating the dict."""
|
||||
|
||||
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
|
||||
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
|
||||
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
|
||||
chat-template consumers (``apply_chat_template(messages, tools=...)``)
|
||||
keep working out of the box."""
|
||||
|
||||
|
||||
def column_for_style(style: str | None) -> LanguageColumn:
|
||||
"""Map a language style to the column where rows of that style are stored.
|
||||
|
||||
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
|
||||
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
|
||||
to :data:`LANGUAGE_EVENTS`.
|
||||
"""
|
||||
if style is None:
|
||||
return LANGUAGE_EVENTS
|
||||
if style in PERSISTENT_STYLES:
|
||||
return LANGUAGE_PERSISTENT
|
||||
if style in EVENT_ONLY_STYLES:
|
||||
return LANGUAGE_EVENTS
|
||||
raise ValueError(f"Unknown language style: {style!r}")
|
||||
545
src/lerobot/datasets/language_render.py
Normal file
545
src/lerobot/datasets/language_render.py
Normal file
@@ -0,0 +1,545 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
|
||||
from lerobot.utils.utils import unwrap_scalar
|
||||
|
||||
from .language import LANGUAGE_PERSISTENT, column_for_style
|
||||
|
||||
LanguageRow = dict[str, Any]
|
||||
RenderedMessages = dict[str, list[Any]]
|
||||
|
||||
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
||||
|
||||
|
||||
def active_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row of ``style`` that is active at time ``t``.
|
||||
|
||||
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
|
||||
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
|
||||
``camera`` selector. Only valid for persistent styles.
|
||||
"""
|
||||
_validate_persistent_resolver("active_at", style)
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if _timestamp(row) <= t
|
||||
]
|
||||
if not matches:
|
||||
return None
|
||||
latest_ts = max(_timestamp(row) for row in matches)
|
||||
return _select_one(
|
||||
[row for row in matches if _timestamp(row) == latest_ts],
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
)
|
||||
|
||||
|
||||
EMITTED_AT_TOLERANCE_S = 0.1
|
||||
"""Half-window for matching persistent rows to a frame timestamp in
|
||||
``emitted_at``. Persistent timestamps come from parquet (float32) and ``t``
|
||||
is also a float32 from parquet, so in the ideal hot path an exact match
|
||||
would suffice — but any caller that derives ``t`` arithmetically (e.g.
|
||||
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
|
||||
common arithmetic drift without admitting frames that are visibly far
|
||||
apart at typical control rates (30–100 Hz). This does mean two persistent
|
||||
rows of the same selector emitted within 0.1 s of each other cannot be
|
||||
told apart by ``emitted_at`` — acceptable because persistent annotations
|
||||
(subtask / plan / memory transitions) change on a human-action timescale,
|
||||
not at the camera frame rate."""
|
||||
|
||||
|
||||
def emitted_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the row of ``style`` emitted at exactly time ``t``.
|
||||
|
||||
For persistent styles, this matches persistent rows whose own ``timestamp``
|
||||
is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why
|
||||
we use a tolerance instead of bit-equality). For event styles, the
|
||||
``events`` list is assumed to come from the dataset row at frame ``t``
|
||||
(event rows carry no timestamp of their own), so all matching event rows
|
||||
are considered emitted at ``t``. ``camera`` filters by the row's
|
||||
``camera`` field — required to disambiguate when multiple view-dependent
|
||||
rows share ``(t, role)`` across cameras.
|
||||
"""
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S
|
||||
]
|
||||
else:
|
||||
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
|
||||
|
||||
def nth_prev(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that was active ``offset`` steps before ``t``.
|
||||
|
||||
Walks back through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions before the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def nth_next(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
|
||||
|
||||
Walks forward through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions after the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def render_sample(
|
||||
*,
|
||||
recipe: TrainingRecipe,
|
||||
persistent: Sequence[LanguageRow] | None,
|
||||
events: Sequence[LanguageRow] | None,
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None = None,
|
||||
dataset_ctx: Any | None = None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render the chat-style messages for a single dataset sample.
|
||||
|
||||
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
|
||||
at frame timestamp ``t``, then expands the recipe's message templates.
|
||||
Returns ``None`` if the resolved sample contains no target message.
|
||||
"""
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
bindings = _resolve_bindings(
|
||||
selected_recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
return _render_message_recipe(selected_recipe, bindings)
|
||||
|
||||
|
||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||
if recipe.blend is None:
|
||||
return recipe
|
||||
|
||||
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
|
||||
if total_weight <= 0:
|
||||
raise ValueError("Blend weights must sum to a positive value.")
|
||||
|
||||
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
|
||||
cumulative = 0.0
|
||||
last_component: TrainingRecipe | None = None
|
||||
for component in recipe.blend.values():
|
||||
last_component = component
|
||||
cumulative += component.weight or 0.0
|
||||
if draw < cumulative:
|
||||
return component
|
||||
assert last_component is not None
|
||||
return last_component
|
||||
|
||||
|
||||
def _resolve_bindings(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> dict[str, LanguageRow | str | None]:
|
||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||
bindings: dict[str, LanguageRow | str | None] = {
|
||||
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
|
||||
}
|
||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||
for name, spec in specs.items():
|
||||
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
||||
return bindings
|
||||
|
||||
|
||||
def _resolve_task(
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow] = (),
|
||||
sample_idx: int = 0,
|
||||
) -> str | None:
|
||||
"""Return the task string for ``sample_idx``.
|
||||
|
||||
Resolution order:
|
||||
|
||||
1. Explicit ``task`` override (caller-supplied) wins.
|
||||
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
|
||||
deterministically pick one by ``sample_idx`` so each frame of an
|
||||
episode rotates through the available rephrasings across an epoch.
|
||||
This realizes Xiao 2022 / CAST-style task-prompt diversity without
|
||||
changing ``meta/tasks.parquet`` and without forcing recipes to opt
|
||||
in: ``${task}`` automatically picks a rephrasing when one exists,
|
||||
and falls back to the canonical task otherwise. Recipes that want
|
||||
the literal canonical task can override the binding.
|
||||
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
|
||||
backed by ``meta/tasks.parquet``).
|
||||
"""
|
||||
if task is not None:
|
||||
return task
|
||||
|
||||
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
|
||||
if aug_rows:
|
||||
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||
# is process-randomized).
|
||||
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
|
||||
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||
chosen = aug_rows[idx].get("content")
|
||||
if chosen:
|
||||
return str(chosen)
|
||||
|
||||
if dataset_ctx is None:
|
||||
return None
|
||||
if isinstance(dataset_ctx, dict):
|
||||
return dataset_ctx.get("task")
|
||||
return getattr(dataset_ctx, "task", None)
|
||||
|
||||
|
||||
def _resolve_spec(
|
||||
spec: str,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
) -> LanguageRow | None:
|
||||
"""Parse a single binding's resolver expression and dispatch to its function."""
|
||||
match = _RESOLVER_RE.match(spec.strip())
|
||||
if match is None:
|
||||
raise ValueError(f"Invalid resolver expression: {spec!r}")
|
||||
name = match.group("name")
|
||||
kwargs = _parse_resolver_args(match.group("args"))
|
||||
kwargs.pop("t_arg", None)
|
||||
|
||||
if name == "emitted_at":
|
||||
return emitted_at(t, persistent=persistent, events=events, **kwargs)
|
||||
if name == "active_at":
|
||||
return active_at(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_prev":
|
||||
return nth_prev(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_next":
|
||||
return nth_next(t, persistent=persistent, **kwargs)
|
||||
raise ValueError(f"Unknown language resolver: {name!r}")
|
||||
|
||||
|
||||
def _parse_resolver_args(args: str) -> dict[str, Any]:
|
||||
"""Parse a comma-separated resolver argument list into a kwargs dict."""
|
||||
kwargs: dict[str, Any] = {}
|
||||
if not args.strip():
|
||||
return kwargs
|
||||
|
||||
parts = [part.strip() for part in args.split(",") if part.strip()]
|
||||
for part in parts:
|
||||
if part == "t":
|
||||
kwargs["t_arg"] = True
|
||||
continue
|
||||
if "=" not in part:
|
||||
raise ValueError(f"Invalid resolver argument: {part!r}")
|
||||
key, value = (item.strip() for item in part.split("=", 1))
|
||||
if key == "offset":
|
||||
kwargs[key] = int(value)
|
||||
else:
|
||||
kwargs[key] = value.strip("\"'")
|
||||
return kwargs
|
||||
|
||||
|
||||
def _render_message_recipe(
|
||||
recipe: TrainingRecipe,
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> RenderedMessages | None:
|
||||
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
|
||||
assert recipe.messages is not None
|
||||
messages: list[dict[str, Any]] = []
|
||||
streams: list[str | None] = []
|
||||
target_indices: list[int] = []
|
||||
|
||||
for turn in recipe.messages:
|
||||
if turn.if_present is not None and bindings.get(turn.if_present) is None:
|
||||
continue
|
||||
|
||||
message = {"role": turn.role}
|
||||
if turn.content is not None:
|
||||
message["content"] = _render_content(turn.content, bindings)
|
||||
|
||||
if turn.tool_calls_from is not None:
|
||||
row = bindings.get(turn.tool_calls_from)
|
||||
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
|
||||
if tool_calls:
|
||||
message["tool_calls"] = copy.deepcopy(tool_calls)
|
||||
|
||||
message_idx = len(messages)
|
||||
messages.append(message)
|
||||
streams.append(turn.stream)
|
||||
if turn.target:
|
||||
target_indices.append(message_idx)
|
||||
|
||||
if not target_indices:
|
||||
return None
|
||||
|
||||
rendered = {
|
||||
"messages": messages,
|
||||
"message_streams": streams,
|
||||
"target_message_indices": target_indices,
|
||||
}
|
||||
_validate_rendered(rendered)
|
||||
return rendered
|
||||
|
||||
|
||||
def _render_content(
|
||||
content: str | list[dict[str, Any]],
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Substitute bindings into a string or each string field of multimodal blocks."""
|
||||
if isinstance(content, str):
|
||||
return _substitute(content, bindings)
|
||||
|
||||
rendered_blocks = []
|
||||
for block in content:
|
||||
rendered_block = copy.deepcopy(block)
|
||||
for key, value in rendered_block.items():
|
||||
if isinstance(value, str):
|
||||
rendered_block[key] = _substitute(value, bindings)
|
||||
rendered_blocks.append(rendered_block)
|
||||
return rendered_blocks
|
||||
|
||||
|
||||
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
|
||||
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
|
||||
|
||||
def replace(match: re.Match[str]) -> str:
|
||||
"""Resolve a single ``${name}`` match to its bound string value."""
|
||||
name = match.group(1)
|
||||
if name not in bindings:
|
||||
raise ValueError(f"Unknown template binding: {name!r}")
|
||||
value = bindings[name]
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, dict):
|
||||
content = value.get("content")
|
||||
return "" if content is None else str(content)
|
||||
return str(value)
|
||||
|
||||
return PLACEHOLDER_RE.sub(replace, template)
|
||||
|
||||
|
||||
def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
"""Sanity-check the rendered output for stream/target alignment."""
|
||||
messages = rendered["messages"]
|
||||
streams = rendered["message_streams"]
|
||||
target_indices = rendered["target_message_indices"]
|
||||
|
||||
if len(streams) != len(messages):
|
||||
raise ValueError("message_streams must be aligned with messages.")
|
||||
if not target_indices:
|
||||
raise ValueError("Rendered samples must contain at least one target message.")
|
||||
for idx in target_indices:
|
||||
if idx < 0 or idx >= len(messages):
|
||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||
# ``stream`` is enforced non-None at MessageTurn construction time
|
||||
# (see ``MessageTurn.__post_init__``), so a missing stream here would
|
||||
# mean the dataclass invariant was bypassed; no need to re-check.
|
||||
|
||||
|
||||
def _nth_relative(
|
||||
name: str,
|
||||
t: float,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None,
|
||||
offset: int,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
|
||||
_validate_persistent_resolver(name, style)
|
||||
if abs(offset) < 1:
|
||||
raise ValueError(f"{name} offset must be non-zero.")
|
||||
|
||||
rows = sorted(
|
||||
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
|
||||
key=_row_sort_key,
|
||||
)
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
anchor_idx = None
|
||||
for idx, row in enumerate(rows):
|
||||
if _timestamp(row) <= t:
|
||||
anchor_idx = idx
|
||||
else:
|
||||
break
|
||||
|
||||
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
|
||||
|
||||
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
|
||||
return None
|
||||
return rows[target_idx]
|
||||
|
||||
|
||||
def _validate_persistent_resolver(name: str, style: str | None) -> None:
|
||||
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
|
||||
if style is None:
|
||||
raise ValueError(f"{name} requires a persistent style.")
|
||||
if column_for_style(style) != LANGUAGE_PERSISTENT:
|
||||
raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
|
||||
|
||||
|
||||
def _matching_rows(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> list[LanguageRow]:
|
||||
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
|
||||
return [
|
||||
row
|
||||
for row in rows
|
||||
if (style is None or row.get("style") == style)
|
||||
and (role is None or row.get("role") == role)
|
||||
and (tool_name is None or _row_has_tool_name(row, tool_name))
|
||||
and (camera is None or row.get("camera") == camera)
|
||||
]
|
||||
|
||||
|
||||
def _select_one(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the single matching row, or raise if the resolver is ambiguous.
|
||||
|
||||
Multiple matches always raise — even when the caller already passed
|
||||
some selectors — because remaining ambiguity means the data has
|
||||
several rows that look identical to the resolver and the caller
|
||||
needs to pin down a specific one (e.g. add ``camera=...`` for VQA
|
||||
rows shared across cameras).
|
||||
"""
|
||||
if not rows:
|
||||
return None
|
||||
if len(rows) > 1:
|
||||
raise ValueError(
|
||||
f"Ambiguous resolver for style={style!r} role={role!r} "
|
||||
f"tool_name={tool_name!r} camera={camera!r}: {len(rows)} matching rows. "
|
||||
f"Add a selector that distinguishes them."
|
||||
)
|
||||
return rows[0]
|
||||
|
||||
|
||||
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||
"""Stable sort key for both persistent and event rows.
|
||||
|
||||
Event rows lack ``timestamp`` (it is implicit in the frame), so default
|
||||
to ``0.0`` — within a single frame all event rows share the same sort
|
||||
bucket and are tiebroken by ``(style, role)``.
|
||||
"""
|
||||
timestamp = row.get("timestamp")
|
||||
ts = float(unwrap_scalar(timestamp)) if timestamp is not None else 0.0
|
||||
return (ts, row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _timestamp(row: LanguageRow) -> float:
|
||||
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
|
||||
return float(unwrap_scalar(row["timestamp"]))
|
||||
|
||||
|
||||
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
|
||||
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
|
||||
for tool_call in row.get("tool_calls") or []:
|
||||
if isinstance(tool_call, str):
|
||||
continue
|
||||
function = tool_call.get("function") if isinstance(tool_call, dict) else None
|
||||
if isinstance(function, dict) and function.get("name") == tool_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
|
||||
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
|
||||
normalized = []
|
||||
for row in rows:
|
||||
if row is None:
|
||||
continue
|
||||
if hasattr(row, "as_py"):
|
||||
row = row.as_py()
|
||||
if not isinstance(row, dict):
|
||||
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
|
||||
normalized.append(dict(row))
|
||||
return normalized
|
||||
@@ -88,7 +88,6 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
@@ -130,6 +129,9 @@ class DatasetInfo:
|
||||
# Optional metadata
|
||||
robot_type: str | None = None
|
||||
splits: dict[str, str] = field(default_factory=dict)
|
||||
# OpenAI-style tool schemas declared by the dataset. ``None`` means the
|
||||
# dataset doesn't declare any — readers fall back to ``DEFAULT_TOOLS``.
|
||||
tools: list[dict] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Coerce feature shapes from list to tuple — JSON deserialisation
|
||||
@@ -151,11 +153,15 @@ class DatasetInfo:
|
||||
"""Return a JSON-serialisable dict.
|
||||
|
||||
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
||||
Drops ``tools`` when unset so existing datasets keep a clean
|
||||
``info.json``.
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
for ft in d["features"].values():
|
||||
if isinstance(ft.get("shape"), tuple):
|
||||
ft["shape"] = list(ft["shape"])
|
||||
if d.get("tools") is None:
|
||||
d.pop("tools", None)
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -95,6 +95,13 @@ from .relative_action_processor import (
|
||||
from .rename_processor import RenameObservationsProcessorStep, rename_stats
|
||||
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
|
||||
|
||||
# RenderMessagesStep is intentionally NOT re-exported here: it pulls in
|
||||
# `lerobot.datasets.language`, which requires the `[dataset]` extra
|
||||
# (`datasets`, `pyarrow`). Importing it from the processor package would
|
||||
# break every base-install consumer of `lerobot.processor`. Users that
|
||||
# need it import directly:
|
||||
# from lerobot.processor.render_messages_processor import RenderMessagesStep
|
||||
|
||||
__all__ = [
|
||||
"ActionProcessorStep",
|
||||
"AddTeleopActionAsComplimentaryDataStep",
|
||||
|
||||
@@ -174,6 +174,24 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
task_index_value = complementary_data["task_index"]
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
|
||||
complementary_data.pop("language_persistent", None)
|
||||
complementary_data.pop("language_events", None)
|
||||
|
||||
if "messages" in complementary_data:
|
||||
messages = complementary_data["messages"]
|
||||
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
|
||||
complementary_data["messages"] = [messages]
|
||||
|
||||
if "message_streams" in complementary_data:
|
||||
streams = complementary_data["message_streams"]
|
||||
if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
|
||||
complementary_data["message_streams"] = [streams]
|
||||
|
||||
if "target_message_indices" in complementary_data:
|
||||
indices = complementary_data["target_message_indices"]
|
||||
if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
|
||||
complementary_data["target_message_indices"] = [indices]
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -153,26 +153,30 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
|
||||
return x
|
||||
|
||||
|
||||
_COMPLEMENTARY_KEYS = (
|
||||
"task",
|
||||
"index",
|
||||
"task_index",
|
||||
"episode_index",
|
||||
"timestamp",
|
||||
"language_persistent",
|
||||
"language_events",
|
||||
"messages",
|
||||
"message_streams",
|
||||
"target_message_indices",
|
||||
)
|
||||
|
||||
|
||||
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Extract complementary data from a batch dictionary.
|
||||
"""Extract complementary data from a batch dictionary.
|
||||
|
||||
This includes padding flags, task description, and indices.
|
||||
|
||||
Args:
|
||||
batch: The batch dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary with the extracted complementary data.
|
||||
Includes padding flags (any key containing ``_is_pad``) plus the fixed
|
||||
set of metadata / language keys defined in ``_COMPLEMENTARY_KEYS`` —
|
||||
each only when present in ``batch``.
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
|
||||
extras = {k: batch[k] for k in _COMPLEMENTARY_KEYS if k in batch}
|
||||
return {**pad_keys, **extras}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
84
src/lerobot/processor/render_messages_processor.py
Normal file
84
src/lerobot/processor/render_messages_processor.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
|
||||
from lerobot.datasets.language_render import render_sample
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.utils import unwrap_scalar
|
||||
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="render_messages_processor")
|
||||
class RenderMessagesStep(ProcessorStep):
|
||||
"""Processor step that turns raw language columns into rendered chat messages.
|
||||
|
||||
Reads ``language_persistent`` and ``language_events`` from the transition's
|
||||
complementary data, renders them through ``recipe`` at the sample timestamp,
|
||||
and replaces the raw columns with the resulting ``messages`` /
|
||||
``message_streams`` / ``target_message_indices`` keys.
|
||||
"""
|
||||
|
||||
recipe: TrainingRecipe
|
||||
dataset_ctx: Any | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
"""Render messages for a single transition; return ``None`` to drop it."""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
|
||||
events = complementary_data.get(LANGUAGE_EVENTS) or []
|
||||
|
||||
if not persistent and not events:
|
||||
return transition
|
||||
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
|
||||
|
||||
sample_idx = complementary_data.get("index", 0)
|
||||
rendered = render_sample(
|
||||
recipe=self.recipe,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=unwrap_scalar(timestamp),
|
||||
sample_idx=int(unwrap_scalar(sample_idx)),
|
||||
task=complementary_data.get("task"),
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
return None
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data.update(rendered)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Pass features through unchanged; rendering only touches complementary data."""
|
||||
return features
|
||||
@@ -48,6 +48,7 @@ from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.rewards import make_reward_pre_post_processors
|
||||
from lerobot.utils.collate import lerobot_collate_fn
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
@@ -401,6 +402,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
shuffle = True
|
||||
sampler = None
|
||||
|
||||
# Only swap in the language-aware collate when the dataset actually
|
||||
# declares language columns; otherwise stay on PyTorch's default
|
||||
# collate so non-language training runs are unaffected.
|
||||
collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=cfg.num_workers,
|
||||
@@ -409,6 +414,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
sampler=sampler,
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
collate_fn=collate_fn,
|
||||
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
|
||||
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||
)
|
||||
|
||||
65
src/lerobot/utils/collate.py
Normal file
65
src/lerobot/utils/collate.py
Normal file
@@ -0,0 +1,65 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from torch.utils.data._utils.collate import default_collate
|
||||
|
||||
from lerobot.datasets.language import LANGUAGE_COLUMNS
|
||||
|
||||
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
|
||||
|
||||
|
||||
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
|
||||
"""Collate function that preserves Python-list and language fields as lists.
|
||||
|
||||
Drops ``None`` samples (e.g. recipes that yielded no target message), keeps
|
||||
rendered-message and language fields as plain Python lists, and delegates
|
||||
every other key to PyTorch's ``default_collate``.
|
||||
"""
|
||||
batch = [sample for sample in batch if sample is not None]
|
||||
if not batch:
|
||||
return None
|
||||
|
||||
# All-or-nothing per key: a partial-presence batch (e.g. half the samples
|
||||
# carry `messages` and half don't) is a real bug in the upstream
|
||||
# rendering step — silently filtering would hand downstream consumers a
|
||||
# preserved list shorter than the tensor batch. Raise instead so the
|
||||
# mismatch surfaces at the boundary.
|
||||
preserved: dict[str, list[Any]] = {}
|
||||
for key in _PYTHON_LIST_KEYS:
|
||||
presence = [key in sample for sample in batch]
|
||||
if not any(presence):
|
||||
continue
|
||||
if not all(presence):
|
||||
raise ValueError(
|
||||
f"Inconsistent batch: {sum(presence)}/{len(batch)} samples carry {key!r}; "
|
||||
f"every sample in a batch must agree."
|
||||
)
|
||||
preserved[key] = [sample[key] for sample in batch]
|
||||
tensorizable = [
|
||||
{
|
||||
key: value
|
||||
for key, value in sample.items()
|
||||
if key not in _PYTHON_LIST_KEYS and key not in LANGUAGE_COLUMNS
|
||||
}
|
||||
for sample in batch
|
||||
]
|
||||
collated = default_collate(tensorizable)
|
||||
collated.update(preserved)
|
||||
return collated
|
||||
@@ -160,6 +160,25 @@ def has_method(cls: object, method_name: str) -> bool:
|
||||
return hasattr(cls, method_name) and callable(getattr(cls, method_name))
|
||||
|
||||
|
||||
def unwrap_scalar(value: Any) -> Any:
|
||||
"""Unwrap a tensor / numpy scalar / single-element list into a Python scalar.
|
||||
|
||||
Tensors and numpy scalars expose ``.item()``; single-element lists are
|
||||
unwrapped recursively. Anything else is returned unchanged. Centralized
|
||||
here so the language renderer and processor steps share one definition.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``value`` is a list with zero or multiple elements.
|
||||
"""
|
||||
if hasattr(value, "item"):
|
||||
return value.item()
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
|
||||
return unwrap_scalar(value[0])
|
||||
return value
|
||||
|
||||
|
||||
def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
|
||||
"""
|
||||
Return True if a given string can be converted to a numpy dtype.
|
||||
|
||||
Reference in New Issue
Block a user