lerobot-clone/src/lerobot/processor/render_messages_processor.py
Pepijn beb22afd81 review: dedupe regex, centralize column names, harden collate, more tests
* **#2 — dedupe `_PLACEHOLDER_RE`.** The same regex was compiled in
  `recipe.py` and `language_render.py`. Promote to module-level
  `PLACEHOLDER_RE` in `recipe.py` (its primary owner — declares
  template syntax) and import from `language_render.py`.
* **#3 — centralize language column names.** `io_utils.py` had
  hardcoded `{"language_persistent", "language_events"}` literals at
  two sites. Replace with `LANGUAGE_COLUMNS` import so a future column
  rename can't silently desync.
* **#4 — defensive collate preserved-keys.** `lerobot_collate_fn`
  silently filtered language fields from samples that didn't have
  them, which would hand downstream consumers a preserved list
  shorter than the tensor batch. Now: if any sample carries a key,
  every sample in the batch must carry it; otherwise raise a
  `ValueError` so the upstream rendering bug surfaces at the boundary.
* **#5 — `_scalar` rejects non-singleton lists.** Previously a zero-
  or multi-element list fell through and triggered confusing
  `float([])` errors downstream. Now raises `ValueError` with the
  actual length.
* **#6 — refactor `_extract_complementary_data`.** Replace 11 lines
  of `key = {... if ... else {}}` plus an 11-line splat dict with a
  single `_COMPLEMENTARY_KEYS` tuple iterated once.
* **#7 — document `EXTENDED_STYLES`.** Was an empty `set()` with no
  comment. Add a docstring explaining it's an intentional extension
  point: downstream modules append project-local styles before
  `column_for_style` is called.
* **#9 — `tools.mdx` notes the runtime layer is future work.** The
  page referenced `src/lerobot/tools/`, `registry.py`, and
  `get_tools(meta)` — none exist in this PR. Added a callout at the
  start of "How to add your own tool" plus a note on the
  implementations paragraph.
* **#10 — tests for YAML round-trip, malformed rows, blend
  validation.** `test_recipe.py` grew from 1 case to 12 covering:
  blend-or-messages exclusivity, target-turn requirement, blend
  emptiness, weight presence/positivity, nested-blend rejection,
  `from_dict` with nested blends, `from_yaml` / `load_recipe`
  agreement, top-level non-mapping rejection. Added a malformed-row
  test for `_normalize_rows` that asserts non-dict entries raise
  `TypeError`.
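
The all-or-none rule from #4 can be sketched roughly as follows. This is a standalone illustration, not the code in `lerobot_collate_fn`: the helper name `collect_preserved` and its signature are hypothetical, and `PRESERVED_KEYS` simply reuses the two column names quoted in #3 as an assumption about which keys are preserved.

```python
from typing import Any

# Hypothetical stand-in for the preserved keys; per #3 the real code imports
# LANGUAGE_COLUMNS instead of spelling the names out.
PRESERVED_KEYS = ("language_persistent", "language_events")


def collect_preserved(batch: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Gather preserved keys per sample, enforcing all-or-none presence."""
    preserved: dict[str, list[Any]] = {}
    for key in PRESERVED_KEYS:
        present = [key in sample for sample in batch]
        if not any(present):
            continue  # no sample carries this key, so there is nothing to preserve
        if not all(present):
            # Fail at the boundary instead of returning a list shorter than the batch.
            missing = [i for i, has_key in enumerate(present) if not has_key]
            raise ValueError(f"Key {key!r} is missing from batch positions {missing}")
        preserved[key] = [sample[key] for sample in batch]
    return preserved
```

With this shape, each preserved list has exactly one entry per sample, so it stays aligned with the tensor batch.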

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 19:06:38 +02:00


#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.configs.recipe import TrainingRecipe
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
from lerobot.datasets.language_render import render_sample
from lerobot.types import EnvTransition, TransitionKey

from .pipeline import ProcessorStep, ProcessorStepRegistry


@dataclass
@ProcessorStepRegistry.register(name="render_messages_processor")
class RenderMessagesStep(ProcessorStep):
    """Processor step that turns raw language columns into rendered chat messages.

    Reads ``language_persistent`` and ``language_events`` from the transition's
    complementary data, renders them through ``recipe`` at the sample timestamp,
    and replaces the raw columns with the resulting ``messages`` /
    ``message_streams`` / ``target_message_indices`` keys.
    """

    recipe: TrainingRecipe
    dataset_ctx: Any | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition | None:
        """Render messages for a single transition; return ``None`` to drop it."""
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
        persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
        events = complementary_data.get(LANGUAGE_EVENTS) or []

        if not persistent and not events:
            return transition

        timestamp = complementary_data.get("timestamp")
        if timestamp is None:
            raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")

        sample_idx = complementary_data.get("index", 0)
        rendered = render_sample(
            recipe=self.recipe,
            persistent=persistent,
            events=events,
            t=_scalar(timestamp),
            sample_idx=int(_scalar(sample_idx)),
            task=complementary_data.get("task"),
            dataset_ctx=self.dataset_ctx,
        )
        if rendered is None:
            return None

        new_transition = transition.copy()
        new_complementary_data = dict(complementary_data)
        new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
        new_complementary_data.pop(LANGUAGE_EVENTS, None)
        new_complementary_data.update(rendered)
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
        return new_transition

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Pass features through unchanged; rendering only touches complementary data."""
        return features


def _scalar(value: Any) -> float | int:
    """Unwrap a tensor/array/single-element list into a Python scalar."""
    if hasattr(value, "item"):
        return value.item()
    if isinstance(value, list):
        if len(value) != 1:
            raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
        return _scalar(value[0])
    return value