mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
* Drop the unused `events` kwarg from `active_at`/`nth_prev`/`nth_next`; only `emitted_at` actually consults events. The dispatcher in `_resolve_spec` now passes events conditionally. * Replace the dual `_persistent_sort_key`/`_event_sort_key` pair with a single `_row_sort_key` and drop the `sort_key` parameter from `_select_one`. Event rows lack `timestamp` (it is implicit in the frame) and now default to `0.0` for sort purposes — the `(style, role)` tiebreaker is unchanged. * Inline `_select_latest` into `active_at` (its only caller). * Collapse `emitted_at`'s dual-branch into one `_select_one` call. * Tighten `_validate_persistent_resolver` to a single `column_for_style(style) != LANGUAGE_PERSISTENT` check. * Parameterize `test_per_camera_blend_renders_both_views` over the two cameras and factor the sub-recipe builder into `_vqa_subrecipe` so the test no longer hand-rolls two near-identical recipe blocks. Net -98 LOC; behavior, public resolver names, and test expectations unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
530 lines
18 KiB
Python
530 lines
18 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
import hashlib
|
|
import re
|
|
from collections.abc import Sequence
|
|
from typing import Any
|
|
|
|
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
|
|
|
|
from .language import LANGUAGE_PERSISTENT, column_for_style
|
|
|
|
# One language annotation row, represented as a plain dict.
LanguageRow = dict[str, Any]
# Rendered sample: each key maps to a list of values.
RenderedMessages = dict[str, list[Any]]

# Identifier shape shared by resolver names and ``${name}`` placeholders.
_IDENTIFIER_PATTERN = r"[A-Za-z_][A-Za-z0-9_]*"
# Whole-string match for a resolver expression ``name(args)``.
_RESOLVER_RE = re.compile(rf"^(?P<name>{_IDENTIFIER_PATTERN})\((?P<args>.*)\)$")
# ``${name}`` template placeholder; group 1 captures the binding name.
_PLACEHOLDER_RE = re.compile(rf"\$\{{({_IDENTIFIER_PATTERN})\}}")
|
|
|
|
|
|
def active_at(
    t: float,
    *,
    persistent: Sequence[LanguageRow],
    style: str | None = None,
    role: str | None = None,
    tool_name: str | None = None,
    camera: str | None = None,
) -> LanguageRow | None:
    """Return the persistent row of ``style`` that is in effect at time ``t``.

    Among the rows matching the ``style``/``role``/``tool_name``/``camera``
    selectors, the row(s) carrying the greatest ``timestamp`` that is still
    ``<= t`` are handed to ``_select_one``. Returns ``None`` when no matching
    row has a timestamp at or before ``t``. Only valid for persistent styles;
    ``_validate_persistent_resolver`` rejects anything else.
    """
    _validate_persistent_resolver("active_at", style)
    selectors = {"style": style, "role": role, "tool_name": tool_name, "camera": camera}

    # Pair every matching row with its timestamp once, then keep those not in the future.
    stamped = [(_timestamp(row), row) for row in _matching_rows(persistent, **selectors)]
    eligible = [(ts, row) for ts, row in stamped if ts <= t]
    if not eligible:
        return None

    newest = max(ts for ts, _ in eligible)
    return _select_one([row for ts, row in eligible if ts == newest], **selectors)
|
|
|
|
|
|
def emitted_at(
    t: float,
    *,
    persistent: Sequence[LanguageRow],
    events: Sequence[LanguageRow],
    style: str | None = None,
    role: str | None = None,
    tool_name: str | None = None,
    camera: str | None = None,
) -> LanguageRow | None:
    """Return the row of ``style`` that was emitted exactly at time ``t``.

    When ``style`` maps to the persistent column, candidates are the matching
    ``persistent`` rows whose ``timestamp`` equals ``t``. Otherwise the
    candidates are all matching rows of ``events``; no timestamp is read from
    them, so the caller is expected to pass the events belonging to the frame
    at ``t``. ``camera`` filters on the row's ``camera`` field, which lets
    callers disambiguate rows that otherwise share the same selectors.
    """
    selectors = {"style": style, "role": role, "tool_name": tool_name, "camera": camera}
    is_persistent = column_for_style(style) == LANGUAGE_PERSISTENT

    candidates = _matching_rows(persistent if is_persistent else events, **selectors)
    if is_persistent:
        # Persistent rows carry their own timestamp; require an exact hit.
        candidates = [row for row in candidates if _timestamp(row) == t]
    return _select_one(candidates, **selectors)
|
|
|
|
|
|
def nth_prev(
    t: float,
    *,
    persistent: Sequence[LanguageRow],
    style: str | None = None,
    offset: int = 1,
    role: str | None = None,
    tool_name: str | None = None,
    camera: str | None = None,
) -> LanguageRow | None:
    """Return the persistent row ``offset`` positions before the one active at ``t``.

    Delegates to ``_nth_relative`` with the offset negated, so the lookup
    moves backwards through the chronologically sorted rows that match
    ``style`` and the optional ``role``/``tool_name``/``camera`` selectors.
    Only valid for persistent styles.
    """
    return _nth_relative(
        name="nth_prev",
        t=t,
        persistent=persistent,
        style=style,
        offset=-offset,
        role=role,
        tool_name=tool_name,
        camera=camera,
    )
|
|
|
|
|
|
def nth_next(
    t: float,
    *,
    persistent: Sequence[LanguageRow],
    style: str | None = None,
    offset: int = 1,
    role: str | None = None,
    tool_name: str | None = None,
    camera: str | None = None,
) -> LanguageRow | None:
    """Return the persistent row ``offset`` positions after the one active at ``t``.

    Delegates to ``_nth_relative`` with the offset unchanged, so the lookup
    moves forwards through the chronologically sorted rows that match
    ``style`` and the optional ``role``/``tool_name``/``camera`` selectors.
    Only valid for persistent styles.
    """
    return _nth_relative(
        name="nth_next",
        t=t,
        persistent=persistent,
        style=style,
        offset=offset,
        role=role,
        tool_name=tool_name,
        camera=camera,
    )
|
|
|
|
|
|
def render_sample(
    *,
    recipe: TrainingRecipe,
    persistent: Sequence[LanguageRow] | None,
    events: Sequence[LanguageRow] | None,
    t: float,
    sample_idx: int,
    task: str | None = None,
    dataset_ctx: Any | None = None,
) -> RenderedMessages | None:
    """Render the chat-style messages for one dataset sample.

    The language rows are normalized to plain dicts (``None`` inputs are
    treated as empty), a recipe is selected for ``sample_idx``, its bindings
    are resolved at frame timestamp ``t``, and the message templates are
    expanded. Returns ``None`` when the rendered sample has no target message.
    """
    # Normalize before selecting the recipe so row-type errors surface first.
    rows_by_kind = {
        "persistent": _normalize_rows(persistent or []),
        "events": _normalize_rows(events or []),
    }
    chosen = _select_recipe(recipe, sample_idx)
    resolved = _resolve_bindings(
        chosen,
        **rows_by_kind,
        t=t,
        sample_idx=sample_idx,
        task=task,
        dataset_ctx=dataset_ctx,
    )
    return _render_message_recipe(chosen, resolved)
|
|
|
|
|
|
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
    """Choose the blend component for ``sample_idx``, or return ``recipe`` itself.

    Recipes without a blend are returned unchanged. Otherwise a draw in
    ``[0, total_weight]`` is derived from a blake2b hash of ``sample_idx`` and
    the first component whose cumulative weight exceeds the draw is returned.
    Missing (``None``) weights count as ``0.0``.

    Raises:
        ValueError: if the blend weights do not sum to a positive value.
    """
    blend = recipe.blend
    if blend is None:
        return recipe

    components = list(blend.values())
    weights = [component.weight or 0.0 for component in components]
    total_weight = sum(weights)
    if total_weight <= 0:
        raise ValueError("Blend weights must sum to a positive value.")

    # Hash-based draw: reproducible for a given sample_idx across runs.
    seed = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
    draw = int.from_bytes(seed, "big") / 2**64 * total_weight

    running = 0.0
    for component, weight in zip(components, weights):
        running += weight
        if draw < running:
            return component
    # Only reachable through floating-point rounding; a positive total
    # guarantees at least one component exists.
    return components[-1]
|
|
|
|
|
|
def _resolve_bindings(
    recipe: TrainingRecipe,
    *,
    persistent: Sequence[LanguageRow],
    events: Sequence[LanguageRow],
    t: float,
    sample_idx: int,
    task: str | None,
    dataset_ctx: Any | None,
) -> dict[str, LanguageRow | str | None]:
    """Resolve ``task`` plus every binding spec of ``recipe`` at time ``t``.

    The recipe's own bindings are layered over ``DEFAULT_BINDINGS``, so a
    recipe entry replaces the default of the same name. The ``task`` entry is
    resolved first and is overwritten if a spec is also named ``task``.
    """
    resolved: dict[str, LanguageRow | str | None] = {
        "task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
    }
    merged_specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
    resolved.update(
        (name, _resolve_spec(spec, persistent=persistent, events=events, t=t))
        for name, spec in merged_specs.items()
    )
    return resolved
|
|
|
|
|
|
def _resolve_task(
    task: str | None,
    dataset_ctx: Any | None,
    *,
    persistent: Sequence[LanguageRow] = (),
    sample_idx: int = 0,
) -> str | None:
    """Work out the task string to bind for ``sample_idx``.

    Precedence:

    1. A caller-supplied ``task`` is returned as-is.
    2. If ``persistent`` holds rows with style ``task_aug`` and role ``user``,
       one of them is picked by a blake2b hash of ``sample_idx`` and its
       ``content`` is returned (as ``str``) when it is non-empty. This lets
       ``${task}`` rotate through task rephrasings without recipes opting in.
    3. Otherwise the ``task`` entry/attribute of ``dataset_ctx`` is returned,
       or ``None`` when ``dataset_ctx`` is ``None`` or has no such value.
    """
    if task is not None:
        return task

    candidates = [
        row for row in persistent if row.get("style") == "task_aug" and row.get("role") == "user"
    ]
    if candidates:
        # blake2b instead of ``hash``: the built-in is randomized per process,
        # which would make the pick irreproducible across runs.
        digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
        pick = candidates[int.from_bytes(digest, "big") % len(candidates)]
        text = pick.get("content")
        if text:
            return str(text)

    if isinstance(dataset_ctx, dict):
        return dataset_ctx.get("task")
    # ``getattr`` with a default also covers ``dataset_ctx is None``.
    return getattr(dataset_ctx, "task", None)
|
|
|
|
|
|
def _resolve_spec(
    spec: str,
    *,
    persistent: Sequence[LanguageRow],
    events: Sequence[LanguageRow],
    t: float,
) -> LanguageRow | None:
    """Parse one resolver expression and call the resolver it names.

    ``spec`` must look like ``name(args)``. The bare ``t`` argument is only a
    marker in the expression and is dropped; ``t`` is always passed
    positionally from this function's own ``t``. Only ``emitted_at`` receives
    ``events``; the other resolvers work on ``persistent`` rows alone.

    Raises:
        ValueError: if ``spec`` is not a resolver expression or names an
            unknown resolver.
    """
    parsed = _RESOLVER_RE.match(spec.strip())
    if parsed is None:
        raise ValueError(f"Invalid resolver expression: {spec!r}")

    name = parsed.group("name")
    kwargs = {
        key: value
        for key, value in _parse_resolver_args(parsed.group("args")).items()
        if key != "t_arg"
    }

    if name == "emitted_at":
        return emitted_at(t, persistent=persistent, events=events, **kwargs)

    persistent_only = {"active_at": active_at, "nth_prev": nth_prev, "nth_next": nth_next}
    resolver = persistent_only.get(name)
    if resolver is None:
        raise ValueError(f"Unknown language resolver: {name!r}")
    return resolver(t, persistent=persistent, **kwargs)
|
|
|
|
|
|
def _parse_resolver_args(args: str) -> dict[str, Any]:
    """Parse a comma-separated resolver argument list into a kwargs dict.

    Empty segments are ignored. A bare ``t`` sets ``"t_arg": True``. Every
    other segment must be ``key=value``: ``offset`` is converted to ``int``,
    any other value has surrounding quote characters stripped. Values
    containing commas are not supported because the list is split on ``,``.

    Args:
        args: the text between the parentheses of a resolver expression.

    Returns:
        The parsed keyword arguments (empty when ``args`` is blank).

    Raises:
        ValueError: if a segment is not ``t`` or ``key=value``, if the key is
            empty, or if ``offset`` is not an integer.
    """
    kwargs: dict[str, Any] = {}
    for part in (piece.strip() for piece in args.split(",")):
        if not part:
            continue
        if part == "t":
            kwargs["t_arg"] = True
            continue

        key, separator, value = part.partition("=")
        key, value = key.strip(), value.strip()
        # Reject both "no '=' at all" and "=value": an empty key could only
        # fail later, as a confusing TypeError at the resolver call.
        if not separator or not key:
            raise ValueError(f"Invalid resolver argument: {part!r}")

        if key == "offset":
            try:
                kwargs[key] = int(value)
            except ValueError as err:
                # Same exception type as before, but name the offending argument.
                raise ValueError(f"Invalid resolver offset: {value!r} (expected an integer)") from err
        else:
            kwargs[key] = value.strip("\"'")
    return kwargs
|
|
|
|
|
|
def _render_message_recipe(
    recipe: TrainingRecipe,
    bindings: dict[str, LanguageRow | str | None],
) -> RenderedMessages | None:
    """Turn ``recipe.messages`` into rendered chat messages using ``bindings``.

    A turn is skipped when its ``if_present`` binding resolves to ``None``.
    Each kept turn yields a message with its role, its rendered content (if
    any), and a deep copy of the ``tool_calls`` of the row bound under
    ``tool_calls_from`` (if that row is a dict with non-empty tool calls).
    Returns ``None`` when no kept turn is marked as a target; otherwise the
    result is validated and returned.
    """
    assert recipe.messages is not None
    rendered_messages: list[dict[str, Any]] = []
    stream_names: list[str | None] = []
    targets: list[int] = []

    for turn in recipe.messages:
        gate = turn.if_present
        if gate is not None and bindings.get(gate) is None:
            continue

        entry = {"role": turn.role}
        if turn.content is not None:
            entry["content"] = _render_content(turn.content, bindings)

        source = turn.tool_calls_from
        if source is not None:
            bound = bindings.get(source)
            calls = bound.get("tool_calls") if isinstance(bound, dict) else None
            if calls:
                # Deep copy so the rendered sample never aliases the source row.
                entry["tool_calls"] = copy.deepcopy(calls)

        rendered_messages.append(entry)
        stream_names.append(turn.stream)
        if turn.target:
            targets.append(len(rendered_messages) - 1)

    if not targets:
        return None

    result = {
        "messages": rendered_messages,
        "message_streams": stream_names,
        "target_message_indices": targets,
    }
    _validate_rendered(result)
    return result
|
|
|
|
|
|
def _render_content(
    content: str | list[dict[str, Any]],
    bindings: dict[str, LanguageRow | str | None],
) -> str | list[dict[str, Any]]:
    """Apply binding substitution to a string or to a list of content blocks.

    For a list, every block is deep-copied and each of its top-level string
    values is substituted; non-string values are left as copied. The input
    blocks themselves are not modified.
    """
    if isinstance(content, str):
        return _substitute(content, bindings)

    expanded = []
    for block in content:
        clone = copy.deepcopy(block)
        # Snapshot the keys; only values are reassigned below.
        for field in list(clone):
            if isinstance(clone[field], str):
                clone[field] = _substitute(clone[field], bindings)
        expanded.append(clone)
    return expanded
|
|
|
|
|
|
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
    """Expand every ``${name}`` placeholder in ``template`` from ``bindings``.

    A dict binding contributes its ``content`` value; ``None`` (either the
    binding itself or a dict's ``content``) becomes the empty string; anything
    else is converted with ``str``.

    Raises:
        ValueError: if a placeholder names a key that is not in ``bindings``.
    """

    def expand(found: re.Match[str]) -> str:
        """Return the replacement text for one placeholder match."""
        key = found.group(1)
        if key not in bindings:
            raise ValueError(f"Unknown template binding: {key!r}")
        bound = bindings[key]
        if isinstance(bound, dict):
            bound = bound.get("content")
        return "" if bound is None else str(bound)

    return _PLACEHOLDER_RE.sub(expand, template)
|
|
|
|
|
|
def _validate_rendered(rendered: RenderedMessages) -> None:
    """Check that streams and target indices line up with the messages.

    Raises:
        ValueError: if ``message_streams`` and ``messages`` differ in length,
            if there are no target indices, if a target index falls outside
            ``messages``, or if any stream entry is ``None``.
    """
    messages = rendered["messages"]
    streams = rendered["message_streams"]
    target_indices = rendered["target_message_indices"]
    count = len(messages)

    if len(streams) != count:
        raise ValueError("message_streams must be aligned with messages.")
    if not target_indices:
        raise ValueError("Rendered samples must contain at least one target message.")

    # Report the first offending target index, then the first stream-less message.
    for idx in target_indices:
        if not 0 <= idx < count:
            raise ValueError(f"Target message index {idx} is out of bounds.")
    unstreamed = next((idx for idx, stream in enumerate(streams) if stream is None), None)
    if unstreamed is not None:
        raise ValueError(f"Rendered message {unstreamed} has no stream.")
|
|
|
|
|
|
def _nth_relative(
    name: str,
    t: float,
    persistent: Sequence[LanguageRow],
    style: str | None,
    offset: int,
    role: str | None,
    tool_name: str | None,
    camera: str | None,
) -> LanguageRow | None:
    """Common implementation of ``nth_prev`` / ``nth_next`` using a signed ``offset``.

    The matching rows are sorted with ``_row_sort_key`` and the anchor is the
    last row whose timestamp is ``<= t``. The result is the row ``offset``
    positions from the anchor. When no row is at or before ``t``, a positive
    offset counts from the start of the list (``offset=1`` is the first row)
    and a negative offset yields ``None``. Out-of-range targets yield ``None``.

    Raises:
        ValueError: if ``style`` is not a persistent style or ``offset`` is zero.
    """
    _validate_persistent_resolver(name, style)
    if -1 < offset < 1:
        raise ValueError(f"{name} offset must be non-zero.")

    ordered = sorted(
        _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
        key=_row_sort_key,
    )
    if not ordered:
        return None

    # Rows are chronological, so stop at the first one that lies after ``t``.
    anchor = None
    for position, row in enumerate(ordered):
        if not _timestamp(row) <= t:
            break
        anchor = position

    if anchor is not None:
        target = anchor + offset
    elif offset > 0:
        target = offset - 1
    else:
        return None

    return ordered[target] if 0 <= target < len(ordered) else None
|
|
|
|
|
|
def _validate_persistent_resolver(name: str, style: str | None) -> None:
    """Ensure ``style`` is present and maps to the persistent column.

    ``name`` is the resolver name used in the error message.

    Raises:
        ValueError: if ``style`` is ``None`` or is not a persistent style.
    """
    if style is None:
        raise ValueError(f"{name} requires a persistent style.")

    column = column_for_style(style)
    if column != LANGUAGE_PERSISTENT:
        raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
|
|
|
|
|
|
def _matching_rows(
    rows: Sequence[LanguageRow],
    *,
    style: str | None,
    role: str | None,
    tool_name: str | None,
    camera: str | None,
) -> list[LanguageRow]:
    """Filter ``rows`` by whichever selectors are not ``None``.

    ``style``, ``role`` and ``camera`` are compared with the row's field of
    the same name; ``tool_name`` matches when any of the row's tool calls
    invokes that function. Input order is preserved.
    """
    selected: list[LanguageRow] = []
    for row in rows:
        if style is not None and row.get("style") != style:
            continue
        if role is not None and row.get("role") != role:
            continue
        if tool_name is not None and not _row_has_tool_name(row, tool_name):
            continue
        if camera is not None and row.get("camera") != camera:
            continue
        selected.append(row)
    return selected
|
|
|
|
|
|
def _select_one(
    rows: Sequence[LanguageRow],
    *,
    style: str | None,
    role: str | None,
    tool_name: str | None,
    camera: str | None,
) -> LanguageRow | None:
    """Reduce ``rows`` to a single row, or ``None`` when there are none.

    Several rows are only tolerated when at least one of ``role``,
    ``tool_name`` or ``camera`` was given; the first row in ``_row_sort_key``
    order is then returned so the choice is stable.

    Raises:
        ValueError: if more than one row remains and none of ``role``,
            ``tool_name`` or ``camera`` was supplied.
    """
    if not rows:
        return None

    unqualified = role is None and tool_name is None and camera is None
    if unqualified and len(rows) > 1:
        raise ValueError(
            f"Ambiguous resolver for style={style!r}; add role=..., tool_name=..., "
            f"or camera=... to disambiguate."
        )

    ordered = sorted(rows, key=_row_sort_key)
    return ordered[0]
|
|
|
|
|
|
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
    """Build the ``(timestamp, style, role)`` ordering key for a row.

    A missing or ``None`` timestamp sorts as ``0.0``: event rows carry no
    timestamp of their own, so within one frame they fall back to the
    ``(style, role)`` tiebreak. Values exposing ``.item()`` (e.g. numpy
    scalars) are unwrapped first. Missing ``style``/``role`` become ``""``.
    """
    raw = row.get("timestamp")
    if raw is None:
        seconds = 0.0
    else:
        if hasattr(raw, "item"):
            raw = raw.item()
        seconds = float(raw)
    return (seconds, row.get("style") or "", row.get("role") or "")
|
|
|
|
|
|
def _timestamp(row: LanguageRow) -> float:
    """Return ``row["timestamp"]`` as a built-in ``float``.

    Values exposing ``.item()`` (e.g. numpy scalars) are unwrapped first.
    A row without a ``timestamp`` key raises ``KeyError``.
    """
    raw = row["timestamp"]
    if hasattr(raw, "item"):
        raw = raw.item()
    return float(raw)
|
|
|
|
|
|
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
    """Tell whether any tool call of ``row`` invokes the function ``tool_name``.

    Only dict tool calls whose ``function`` entry is itself a dict are
    inspected; strings and any other shapes are ignored. A missing or empty
    ``tool_calls`` field gives ``False``.
    """
    calls = row.get("tool_calls") or []
    functions = (call.get("function") for call in calls if isinstance(call, dict))
    return any(isinstance(fn, dict) and fn.get("name") == tool_name for fn in functions)
|
|
|
|
|
|
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
    """Convert raw language rows into a fresh list of plain ``dict`` rows.

    Objects exposing ``as_py()`` (e.g. pyarrow scalars) are unwrapped first.
    ``None`` entries are skipped, including entries whose ``as_py()`` yields
    ``None``, so a null scalar is treated like a missing row instead of
    raising. Every kept row is shallow-copied, so changes to the returned
    dicts' keys do not affect the input.

    Args:
        rows: raw rows; each item is ``None``, a dict, or an object with
            ``as_py()`` that produces one of those.

    Returns:
        The normalized rows, in input order.

    Raises:
        TypeError: if an item (after unwrapping) is neither ``None`` nor a dict.
    """
    normalized: list[LanguageRow] = []
    for row in rows:
        if hasattr(row, "as_py"):
            row = row.as_py()
        # Checked after unwrapping so null scalars are skipped too.
        if row is None:
            continue
        if not isinstance(row, dict):
            raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
        normalized.append(dict(row))
    return normalized
|