mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
Adds the steerable annotation pipeline (`lerobot-annotate`) that populates the `language_persistent` and `language_events` columns introduced in PR 1 directly into `data/chunk-*/file-*.parquet`. No flavor namespace, no sidecar tree. Modules produced: - Module 1 (plan_subtasks_memory): Pi0.7-style subtasks, plan (init + refresh on interjection), MEM-style memory at subtask boundaries. - Module 2 (interjections_and_speech): t=0 speech-only acknowledgement, mid-episode paired interjection + speech tool-call atom. - Module 3 (general_vqa): bbox/keypoint/count/attribute/spatial pairs at configurable cadence with one-retry JSON validation. Writer enforces: per-episode persistent identity, exact-frame event timestamps, column routing per `column_for_style`, dataset-level `tools` column with the `say` schema, drops legacy `subtask_index`. Validator runs against staged JSONL artifacts before the writer rewrites parquet. Adds `lerobot-annotate` console script, `annotations` extra (datatrove + optional vllm), `make annotation-e2e` opt-in smoke target, and `docs/source/annotation_pipeline.mdx`. Branched from PR 1 (`feat/language-columns`). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
272 lines
9.9 KiB
Python
272 lines
9.9 KiB
Python
#!/usr/bin/env python
|
||
|
||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
"""Pre-write validation against staged outputs.
|
||
|
||
Runs after Modules 1–3 have all written their per-episode artifacts but
|
||
*before* the writer rewrites parquet shards. The validator never touches
|
||
parquet; it only inspects the staging tree and the source frame timestamps
|
||
exposed by :class:`EpisodeRecord`.
|
||
|
||
Checks (per the plan's "Intermediate staging and validation" section):
|
||
|
||
- exact timestamp alignment against source frame timestamps
|
||
- no orphan speech / interjection pairs
|
||
- plan / memory emission consistency (events have a paired persistent row)
|
||
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
|
||
attribute / spatial)
|
||
- every row maps to its correct column under :func:`column_for_style`
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
from collections.abc import Iterable, Sequence
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
from lerobot.datasets.language import (
|
||
LANGUAGE_EVENTS,
|
||
LANGUAGE_PERSISTENT,
|
||
column_for_style,
|
||
)
|
||
|
||
from .reader import EpisodeRecord
|
||
from .staging import EpisodeStaging
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class ValidationReport:
    """Accumulated findings from a single validation pass over all episodes.

    Errors and warnings are stored as plain message strings in the order they
    were reported; ``episodes_checked`` is incremented by the caller.
    """

    errors: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)
    episodes_checked: int = 0

    @property
    def ok(self) -> bool:
        """``True`` while no error has been recorded (warnings do not count)."""
        return len(self.errors) == 0

    def add_error(self, message: str) -> None:
        """Record one error message."""
        self.errors.append(message)

    def add_warning(self, message: str) -> None:
        """Record one warning message."""
        self.warnings.append(message)

    def summary(self) -> str:
        """Return a one-line ``checked=… errors=… warnings=…`` digest."""
        counters = (
            ("checked", self.episodes_checked),
            ("errors", len(self.errors)),
            ("warnings", len(self.warnings)),
        )
        return " ".join(f"{label}={value}" for label, value in counters)
|
||
|
||
|
||
# Keys that must all be present in a VQA answer payload for it to count as the
# given question type. Order matters: the first matching entry wins.
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
    "bbox": {"detections"},
    "keypoint": {"label", "point_format", "point"},
    "count": {"label", "count"},
    "attribute": {"label", "attribute", "value"},
    "spatial": {"subject", "relation", "object"},
}


def classify_vqa_answer(payload: Any) -> str | None:
    """Guess which VQA question type a decoded answer payload belongs to.

    Args:
        payload: The decoded answer. Anything that is not a ``dict`` is
            unclassifiable.

    Returns:
        The first key of :data:`VQA_ANSWER_SHAPES` whose required keys are all
        present in ``payload`` (extra keys are allowed), or ``None`` when no
        shape matches.
    """
    if isinstance(payload, dict):
        present = frozenset(payload)
        matches = (kind for kind, required in VQA_ANSWER_SHAPES.items() if required <= present)
        return next(matches, None)
    return None
|
||
|
||
|
||
@dataclass
class StagingValidator:
    """Walks the staging tree and produces a :class:`ValidationReport`.

    For every episode the validator loads the staged rows, tags each row with
    the module that produced it, and runs the routing, timestamp, pairing,
    plan/memory and VQA checks. Findings are accumulated on the report: a row
    with an unknown style or an unusable timestamp is reported as an error
    rather than aborting the pass with an exception.

    Attributes:
        timestamp_atol: Absolute tolerance, in seconds, used when matching an
            event timestamp against the source frame timestamps. ``0.0``
            requires an exact match.
    """

    timestamp_atol: float = 0.0  # exact-match by default

    def validate(
        self,
        records: Sequence[EpisodeRecord],
        staging_dir: Path,
    ) -> ValidationReport:
        """Validate the staged rows of every episode in ``records``.

        Args:
            records: Episodes to check, one staging lookup per record.
            staging_dir: Root directory handed to :class:`EpisodeStaging`.

        Returns:
            A report holding every error and warning found, with
            ``episodes_checked`` equal to ``len(records)``.
        """
        report = ValidationReport()
        for record in records:
            self._validate_episode(record, staging_dir, report)
            report.episodes_checked += 1
        return report

    @staticmethod
    def _as_float(value: Any) -> float | None:
        """Return ``value`` converted to ``float``, or ``None`` if missing or non-numeric."""
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def _validate_episode(
        self,
        record: EpisodeRecord,
        staging_dir: Path,
        report: ValidationReport,
    ) -> None:
        """Run every check for one episode and append findings to ``report``.

        NOTE(review): assumes ``EpisodeStaging.read_all()`` returns a mapping
        of module name to an iterable of row dicts — confirm against staging.
        """
        episode_index = record.episode_index
        staged = EpisodeStaging(staging_dir, episode_index).read_all()
        # Copy each row so the module tag never leaks back into staged data.
        all_rows: list[dict[str, Any]] = [
            {**row, "_module": module_name} for module_name, rows in staged.items() for row in rows
        ]

        frame_ts = set(record.frame_timestamps)

        events: list[dict[str, Any]] = []
        persistent: list[dict[str, Any]] = []
        for row in all_rows:
            self._check_column_routing(row, report, episode_index)
            try:
                target_col = column_for_style(row.get("style"))
            except ValueError:
                # Unknown style: _check_column_routing already reported it.
                # Such a row belongs to neither column, so skip it instead of
                # letting the exception abort the whole validation pass.
                continue
            if target_col == LANGUAGE_PERSISTENT:
                persistent.append(row)
            else:
                events.append(row)

        for row in events:
            self._check_event_timestamp_alignment(row, frame_ts, report, episode_index)

        self._check_speech_interjection_pairs(events, report, episode_index)
        self._check_plan_memory_consistency(persistent, events, report, episode_index)
        self._check_vqa_json(events, report, episode_index)

    def _check_column_routing(
        self,
        row: dict[str, Any],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        """Check that the row's style routes to the column its module may write.

        ``module_1`` rows must route to the persistent column; ``module_2`` and
        ``module_3`` rows must route to the events column. A style rejected by
        :func:`column_for_style` is reported as an error.
        """
        style = row.get("style")
        module = row.get("_module")
        try:
            target_col = column_for_style(style)
        except ValueError:
            report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
            return
        if module == "module_1" and target_col != LANGUAGE_PERSISTENT:
            report.add_error(
                f"ep={episode_index} module=module_1 emitted style {style!r} that routes to {target_col} (must be persistent)"
            )
        if module in {"module_2", "module_3"} and target_col != LANGUAGE_EVENTS:
            report.add_error(
                f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
            )

    def _check_event_timestamp_alignment(
        self,
        row: dict[str, Any],
        frame_ts: set[float],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        """Check that an event row's timestamp lands on a source frame.

        With ``timestamp_atol == 0.0`` the timestamp must be a member of
        ``frame_ts``; otherwise it must lie within ``timestamp_atol`` of at
        least one frame timestamp. Missing or non-numeric timestamps are
        reported as errors.
        """
        ts = row.get("timestamp")
        if ts is None:
            report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
            return
        ts_value = self._as_float(ts)
        if ts_value is None:
            report.add_error(f"ep={episode_index}: event row timestamp {ts!r} is not numeric")
            return
        if self.timestamp_atol == 0.0:
            if ts_value not in frame_ts:
                report.add_error(
                    f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
                )
        elif not any(abs(ts_value - frame) <= self.timestamp_atol for frame in frame_ts):
            report.add_error(
                f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
            )

    def _check_speech_interjection_pairs(
        self,
        events: Iterable[dict[str, Any]],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        """Report every interjection that has no speech atom at the same timestamp.

        A speech atom is an event row whose ``style`` is ``None`` and whose
        ``role`` is ``"assistant"``. Speech without an interjection is allowed.
        Rows whose timestamp is missing or non-numeric are skipped here; the
        timestamp-alignment check reports them.
        """
        speech_ts: set[float] = set()
        # dict used as an insertion-ordered set so errors keep first-seen order.
        interjection_ts: dict[float, None] = {}
        for row in events:
            ts_value = self._as_float(row.get("timestamp"))
            if ts_value is None:
                continue
            style = row.get("style")
            if style is None and row.get("role") == "assistant":
                speech_ts.add(ts_value)
            if style == "interjection":
                interjection_ts[ts_value] = None

        for ts in interjection_ts:
            if ts not in speech_ts:
                report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")

    def _styled_timestamps(
        self,
        rows: Sequence[dict[str, Any]],
        style: str,
        report: ValidationReport,
        episode_index: int,
    ) -> set[float]:
        """Collect the timestamps of persistent rows with ``style``, reporting unusable ones."""
        collected: set[float] = set()
        for row in rows:
            if row.get("style") != style:
                continue
            ts_value = self._as_float(row.get("timestamp"))
            if ts_value is None:
                report.add_error(
                    f"ep={episode_index}: persistent {style} row missing or non-numeric timestamp: {row!r}"
                )
                continue
            collected.add(ts_value)
        return collected

    def _check_plan_memory_consistency(
        self,
        persistent: Sequence[dict[str, Any]],
        events: Sequence[dict[str, Any]],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        """Cross-check plan, memory and subtask rows against each other and interjections.

        - warns when persistent rows exist but none of them is a plan;
        - errors for each interjection without a plan row at the same timestamp;
        - warns when memory rows sit at timestamps that carry no subtask row
          (only evaluated when both memory and subtask rows exist).

        Plan, memory or subtask rows with a missing or non-numeric timestamp
        are reported as errors instead of raising.
        """
        plan_ts = self._styled_timestamps(persistent, "plan", report, episode_index)
        memory_ts = self._styled_timestamps(persistent, "memory", report, episode_index)
        subtask_ts = self._styled_timestamps(persistent, "subtask", report, episode_index)

        interjection_ts: set[float] = set()
        for row in events:
            if row.get("style") != "interjection":
                continue
            ts_value = self._as_float(row.get("timestamp"))
            if ts_value is not None:
                interjection_ts.add(ts_value)

        # Judge "plan emitted" on the rows themselves so that a plan row with a
        # bad timestamp yields only its own error, not a misleading warning.
        has_plan = any(row.get("style") == "plan" for row in persistent)
        if persistent and not has_plan:
            report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")

        # every interjection should have a same-timestamp plan refresh
        for ts in sorted(interjection_ts):
            if ts not in plan_ts:
                report.add_error(
                    f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
                )

        # memory should be emitted at subtask boundaries (subset relation)
        if memory_ts and subtask_ts:
            stray = sorted(memory_ts - subtask_ts)
            if stray:
                report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")

    def _check_vqa_json(
        self,
        events: Iterable[dict[str, Any]],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        """Check that each VQA assistant row carries JSON matching a known answer shape.

        Only rows with ``style == "vqa"`` and ``role == "assistant"`` are
        inspected. Null content, undecodable content and payloads that
        :func:`classify_vqa_answer` cannot classify are each reported as an error.
        """
        for row in events:
            if row.get("style") != "vqa" or row.get("role") != "assistant":
                continue
            content = row.get("content")
            if content is None:
                report.add_error(
                    f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
                )
                continue
            try:
                payload = json.loads(content)
            except (TypeError, ValueError) as exc:
                # TypeError covers content that is not str/bytes (e.g. already
                # decoded); ValueError covers json.JSONDecodeError.
                report.add_error(
                    f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
                )
                continue
            if classify_vqa_answer(payload) is None:
                report.add_error(
                    f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
                )
|