#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Final parquet rewrite.

For every episode the writer:

1. reads the staged module outputs,
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
   slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
3. validates, normalizes, and then sorts each slice deterministically,
4. broadcasts the persistent slice across every frame in the episode,
5. for each frame, materializes the sublist of event rows whose timestamp
   exactly equals that frame's timestamp (event rows matching no frame are
   reported with a warning),
6. drops the legacy ``subtask_index`` column,
7. adds a top-level ``tools`` column containing the JSON schema for ``say``,
8. writes the parquet shard back in place (through a temporary sibling file
   that atomically replaces the original).

Invariants enforced here (and re-checked by the validator):

- per-episode persistent slice is byte-identical across every frame;
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
  (timestamps come straight from the source parquet — never recomputed);
- every row passes ``column_for_style(style)``.
"""

from __future__ import annotations

import json
import logging
from collections import defaultdict
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import pyarrow as pa
import pyarrow.parquet as pq

from lerobot.datasets.language import (
    EVENT_ONLY_STYLES,
    LANGUAGE_EVENTS,
    LANGUAGE_PERSISTENT,
    PERSISTENT_STYLES,
    column_for_style,
    validate_camera_field,
)

from .reader import EpisodeRecord
from .staging import EpisodeStaging

logger = logging.getLogger(__name__)

# OpenAI-style function-tool schema for the ``say`` tool; serialized into the
# ``tools`` column of every rewritten shard.
SAY_TOOL_SCHEMA: dict[str, Any] = {
    "type": "function",
    "function": {
        "name": "say",
        "description": "Speak a short utterance to the user via the TTS executor.",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {
                    "type": "string",
                    "description": "The verbatim text to speak.",
                }
            },
            "required": ["text"],
        },
    },
}


def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
    """Sort key for persistent rows: ``(timestamp, style, role)``.

    Missing/None ``style`` and ``role`` sort as the empty string. Requires a
    ``timestamp`` key, so callers must validate rows before sorting.
    """
    return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")


def _row_event_sort_key(row: dict[str, Any]) -> tuple:
    """Sort key for event rows within a single frame bucket.

    Events are bucketed per-frame (all rows in a bucket share one timestamp),
    but within a frame we still want a deterministic order.
    """
    return (
        row.get("style") or "",
        row.get("role") or "",
        row.get("camera") or "",
    )


def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
    """Coerce a staged row into the persistent column's struct shape.

    Raises:
        ValueError: if ``style`` is not in PERSISTENT_STYLES or ``timestamp``
            is missing.
        KeyError: if ``role`` is missing.
    """
    style = row.get("style")
    if style not in PERSISTENT_STYLES:
        raise ValueError(
            f"persistent slice contains row with non-persistent style {style!r}; "
            "row would be misrouted under column_for_style()"
        )
    if "timestamp" not in row:
        raise ValueError(f"persistent row missing timestamp: {row!r}")
    camera = row.get("camera")
    validate_camera_field(style, camera)
    return {
        "role": str(row["role"]),
        "content": None if row.get("content") is None else str(row["content"]),
        "style": style,
        "timestamp": float(row["timestamp"]),
        "camera": None if camera is None else str(camera),
        "tool_calls": _normalize_tool_calls(row.get("tool_calls")),
    }


def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
    """Coerce a staged row into the event column's struct shape (no timestamp).

    The timestamp is dropped because event rows are stored on the frame whose
    timestamp they match.

    Raises:
        ValueError: if ``style`` is neither None nor in EVENT_ONLY_STYLES, or
            does not route to ``language_events``.
        KeyError: if ``role`` is missing.
    """
    style = row.get("style")
    if style is not None and style not in EVENT_ONLY_STYLES:
        raise ValueError(
            f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
        )
    if column_for_style(style) != LANGUAGE_EVENTS:
        raise ValueError(f"event row with style {style!r} would not route to language_events")
    camera = row.get("camera")
    validate_camera_field(style, camera)
    return {
        "role": str(row["role"]),
        "content": None if row.get("content") is None else str(row["content"]),
        "style": style,
        "camera": None if camera is None else str(camera),
        "tool_calls": _normalize_tool_calls(row.get("tool_calls")),
    }


def _normalize_tool_calls(value: Any) -> list[Any] | None:
    """Return a shallow copy of ``tool_calls`` (or None); reject non-lists."""
    if value is None:
        return None
    if not isinstance(value, list):
        raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
    return list(value)


def _validate_atom_invariants(row: dict[str, Any]) -> None:
    """At-least-one of content/tool_calls; style=None implies tool_calls."""
    has_content = row.get("content") is not None
    has_tools = row.get("tool_calls") is not None
    if not (has_content or has_tools):
        raise ValueError(f"row has neither content nor tool_calls: {row!r}")
    if row.get("style") is None and not has_tools:
        raise ValueError(f"style=None requires tool_calls: {row!r}")


def _validate_speech_atom(row: dict[str, Any]) -> None:
    """Speech atoms: role=assistant, style=None, content=None, say tool call.

    Rows with a non-None style are not speech atoms and are accepted as-is.
    Only ``tool_calls[0]`` is inspected; its ``function.arguments`` must be a
    dict carrying a string ``text`` (the shape produced by ``speech_atom``).
    """
    if row.get("style") is not None:
        return  # not a speech atom
    if row.get("role") != "assistant":
        raise ValueError(f"speech atom must have role=assistant: {row!r}")
    if row.get("content") is not None:
        raise ValueError(f"speech atom must have content=null: {row!r}")
    tool_calls = row.get("tool_calls")
    if not tool_calls or not isinstance(tool_calls, list):
        raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
    first = tool_calls[0]
    if not isinstance(first, dict):
        raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
    if first.get("type") != "function":
        raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
    fn = first.get("function") or {}
    if fn.get("name") != "say":
        raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
    args = fn.get("arguments") or {}
    if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
        raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")


def _build_episode_slices(
    ep_staged: dict[str, list[dict[str, Any]]],
) -> tuple[list[dict[str, Any]], dict[float, list[dict[str, Any]]]]:
    """Validate, normalize, and sort one episode's staged rows.

    Args:
        ep_staged: staged rows keyed by producing module name.

    Returns:
        ``(persistent_rows, event_buckets)``: the sorted, normalized persistent
        slice, and a mapping from timestamp to the sorted, normalized event
        rows at that timestamp.

    Raises:
        ValueError: if any row fails validation or lacks a ``timestamp``.
    """
    raw_persistent: list[dict[str, Any]] = []
    raw_events: list[dict[str, Any]] = []  # carry timestamp until bucketed
    for rows in ep_staged.values():
        for row in rows:
            if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
                raw_persistent.append(row)
            else:
                raw_events.append(row)

    # Validate and normalize *before* sorting: the sort key reads
    # ``row["timestamp"]``, so sorting raw rows first would surface a missing
    # timestamp as a bare KeyError instead of the descriptive ValueError
    # raised by _normalize_persistent_row.
    persistent: list[dict[str, Any]] = []
    for row in raw_persistent:
        _validate_atom_invariants(row)
        _validate_speech_atom(row)
        persistent.append(_normalize_persistent_row(row))
    persistent.sort(key=_row_persistent_sort_key)

    buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
    for row in raw_events:
        _validate_atom_invariants(row)
        _validate_speech_atom(row)
        if "timestamp" not in row:
            raise ValueError(f"event row missing timestamp: {row!r}")
        buckets[float(row["timestamp"])].append(_normalize_event_row(row))
    for bucket in buckets.values():
        bucket.sort(key=_row_event_sort_key)
    return persistent, buckets


def _warn_unmatched_events(
    path: Path,
    events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]],
    matched: set[tuple[int, float]],
) -> None:
    """Log event buckets that were not attached to any frame of the shard.

    Such rows are not written. Without this warning they would vanish
    silently, for example when a staged timestamp does not exactly equal any
    frame timestamp.
    """
    for ep, buckets in events_by_ep_ts.items():
        orphaned = sorted(ts for ts in buckets if (ep, ts) not in matched)
        if not orphaned:
            continue
        n_rows = sum(len(buckets[ts]) for ts in orphaned)
        logger.warning(
            "%s: episode %s has %d event row(s) at %d timestamp(s) matching no frame; "
            "they were not written (first timestamps: %s)",
            path,
            ep,
            n_rows,
            len(orphaned),
            orphaned[:5],
        )


def _atomic_write_table(table: pa.Table, path: Path) -> None:
    """Write ``table`` as parquet to ``path`` without a window of data loss.

    The table is written to a sibling ``<name>.tmp`` file and then moved over
    ``path`` with ``Path.replace``. If the process dies mid-write, the original
    shard is left intact. On any error the temporary file is removed and the
    exception propagates.
    """
    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        pq.write_table(table, tmp_path)
        tmp_path.replace(path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise


@dataclass
class LanguageColumnsWriter:
    """Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""

    # When True, the legacy ``subtask_index`` column is dropped from the output.
    drop_existing_subtask_index: bool = True

    def write_all(
        self,
        records: Sequence[EpisodeRecord],
        staging_dir: Path,
        root: Path,
    ) -> list[Path]:
        """Rewrite every parquet shard referenced by ``records``.

        Records are grouped by ``record.data_path`` so each shard is read and
        written exactly once. ``root`` is forwarded to ``_rewrite_one``, which
        currently does not use it.

        Returns:
            The shard paths that were rewritten, in first-seen order.
        """
        episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
        for record in records:
            episodes_by_path[record.data_path].append(record)

        written: list[Path] = []
        for path, eps in episodes_by_path.items():
            self._rewrite_one(path, eps, staging_dir, root)
            written.append(path)
        return written

    def _rewrite_one(
        self,
        path: Path,
        episodes: Sequence[EpisodeRecord],
        staging_dir: Path,
        root: Path,
    ) -> None:
        """Rebuild the language columns of one shard and write it back in place.

        Raises:
            ValueError: if the shard lacks an ``episode_index`` or ``timestamp``
                column, or if any staged row fails validation.
        """
        table = pq.read_table(path)
        column_names = table.column_names
        if "episode_index" not in column_names or "timestamp" not in column_names:
            raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")

        # Ensure we cover every episode in the file. Episodes that don't have
        # staging artifacts are passed through with empty annotation lists —
        # this keeps the writer idempotent and safe for partial reruns.
        persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
        events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
        for record in episodes:
            ep_staged = EpisodeStaging(staging_dir, record.episode_index).read_all()
            persistent, buckets = _build_episode_slices(ep_staged)
            persistent_by_ep[record.episode_index] = persistent
            events_by_ep_ts[record.episode_index] = buckets

        episode_col = table.column("episode_index").to_pylist()
        ts_col = table.column("timestamp").to_pylist()

        per_row_persistent: list[list[dict[str, Any]]] = []
        per_row_events: list[list[dict[str, Any]]] = []
        matched: set[tuple[int, float]] = set()  # (episode, ts) buckets that hit a frame
        no_buckets: dict[float, list[dict[str, Any]]] = {}
        for ep, raw_ts in zip(episode_col, ts_col):
            ts = float(raw_ts)
            # The same list object is shared by every frame of the episode,
            # which keeps the persistent slice identical across frames.
            per_row_persistent.append(persistent_by_ep.get(ep, []))
            frame_events = events_by_ep_ts.get(ep, no_buckets).get(ts)
            if frame_events is None:
                per_row_events.append([])
            else:
                matched.add((ep, ts))
                per_row_events.append(frame_events)

        _warn_unmatched_events(path, events_by_ep_ts, matched)

        new_table = self._materialize_table(
            table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
        )
        _atomic_write_table(new_table, path)

    def _materialize_table(
        self,
        table: pa.Table,
        persistent: list[list[dict[str, Any]]],
        events: list[list[dict[str, Any]]],
        *,
        drop_old: bool,
    ) -> pa.Table:
        """Return ``table`` with canonical language and ``tools`` columns.

        Any existing ``language_persistent`` / ``language_events`` / ``tools``
        columns (and ``subtask_index`` when ``drop_old``) are removed. The new
        columns are appended in the order persistent, events, tools.
        """
        cols = []
        names = []
        for name in table.column_names:
            if drop_old and name == "subtask_index":
                continue
            if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS, "tools"):
                continue  # we'll re-add canonical versions
            cols.append(table.column(name))
            names.append(name)

        # We let pyarrow infer struct/list schema rather than passing the
        # canonical type from `lerobot.datasets.language` directly: that type
        # uses `pa.json_()` for the `tool_calls` element type, which
        # `pa.array(..., type=...)` cannot materialize from Python lists on
        # current pyarrow versions. The inferred schema round-trips through
        # parquet and `LeRobotDataset` correctly — see PR 1's
        # `tests/datasets/test_language.py` which exercises the same flow.
        # NOTE(review): with inference, a shard whose lists are all empty (or
        # whose rows never set a given field) presumably gets null-typed
        # children, so column types may differ between shards — confirm that
        # downstream readers tolerate this.
        persistent_arr = pa.array(persistent)
        events_arr = pa.array(events)
        cols.extend([persistent_arr, events_arr])
        names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])

        # Dataset-level tools column. Store the JSON schema as a string per
        # row (broadcast-identical, parquet dictionary-encodes it) — string
        # storage avoids requiring pa.json_() on every consumer.
        tools_json = json.dumps([SAY_TOOL_SCHEMA], sort_keys=True)
        tools_arr = pa.array([tools_json] * table.num_rows, type=pa.string())
        cols.append(tools_arr)
        names.append("tools")

        return pa.Table.from_arrays(cols, names=names)


def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
    """Build a canonical speech tool-call atom for the events column.

    The result satisfies ``_validate_speech_atom``: role ``assistant``,
    ``style``/``content``/``camera`` None, and a single ``say`` call whose
    arguments carry ``text``.
    """
    return {
        "role": "assistant",
        "content": None,
        "style": None,
        "timestamp": float(timestamp),
        "camera": None,
        "tool_calls": [
            {
                "type": "function",
                "function": {
                    "name": "say",
                    "arguments": {"text": text},
                },
            }
        ],
    }


def normalize_rows_for_writer(
    rows: Iterable[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Helper used by tests/validators to partition a flat row list into
    (persistent_rows, event_rows) using ``column_for_style``.

    Rows are returned unmodified and in input order. Any row that does not
    route to ``language_persistent`` lands in the event list.
    """
    persistent: list[dict[str, Any]] = []
    events: list[dict[str, Any]] = []
    for row in rows:
        if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
            persistent.append(row)
        else:
            events.append(row)
    return persistent, events