diff --git a/docs/source/annotation_pipeline.mdx b/docs/source/annotation_pipeline.mdx index c9eefeb3e..cb2012249 100644 --- a/docs/source/annotation_pipeline.mdx +++ b/docs/source/annotation_pipeline.mdx @@ -136,9 +136,14 @@ short manipulation episodes. ### Which modules run -Each module can be turned off independently to iterate on one at a time: -`--plan.enabled`, `--interjections.enabled`, `--vqa.enabled` (all -`true` by default). +Every module is on by default and can be toggled independently (set to +`false` to skip it, e.g. to iterate on one module at a time): + +| Flag | Default | Turns off | +| ------------------------- | ------- | ----------------------------------- | +| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug | +| `--interjections.enabled` | `true` | interjections + speech atoms | +| `--vqa.enabled` | `true` | the VQA pairs | ### The VLM (`--vlm.*`) diff --git a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py index d054f9eb5..c76a6acad 100644 --- a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py +++ b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py @@ -71,47 +71,16 @@ class PlanSubtasksMemoryModule: effective_task = self._resolve_effective_task(record) # task_aug rows at t=0: phrasings the renderer rotates ${task} through. # Either the structured 5-axis taxonomy (task_aug_axes.enabled) or - # free-form n_task_rephrasings. + # free-form n_task_rephrasings; the effective task is always emitted + # first so the rotation covers the source-of-truth phrasing. t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0 - axes_cfg = self.config.task_aug_axes - if axes_cfg.enabled and effective_task: - variants = self._generate_task_aug_by_axes(effective_task, axes_cfg) - seen: set[str] = set() - ordered = [effective_task, *variants] - for phrasing in ordered: - key = phrasing.strip() - if not key or key in seen: - continue - seen.add(key) - rows.append( - { - "role": "user", - "content": key, - "style": "task_aug", - "timestamp": t0, - "tool_calls": None, - } - ) + variants: list[str] | None = None + if self.config.task_aug_axes.enabled and effective_task: + variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes) elif self.config.n_task_rephrasings > 0 and effective_task: - rephrasings = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings) - # Include the effective task first so the rotation always covers - # the source-of-truth phrasing, not just synthetic ones. - seen = set() - ordered = [effective_task, *rephrasings] - for phrasing in ordered: - key = phrasing.strip() - if not key or key in seen: - continue - seen.add(key) - rows.append( - { - "role": "user", - "content": key, - "style": "task_aug", - "timestamp": t0, - "tool_calls": None, - } - ) + variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings) + if variants is not None: + rows.extend(self._task_aug_rows([effective_task, *variants], t0)) subtask_spans = self._generate_subtasks(record, task=effective_task) @@ -233,6 +202,21 @@ class PlanSubtasksMemoryModule: return True return task.lower() in self._PLACEHOLDER_TASKS + @staticmethod + def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]: + """Build deduplicated ``task_aug`` rows (role=user) at ``t0``.""" + seen: set[str] = set() + rows: list[dict[str, Any]] = [] + for phrasing in phrasings: + key = phrasing.strip() + if not key or key in seen: + continue + seen.add(key) + rows.append( + {"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None} + ) + return rows + # ------------------------------------------------------------------ # VLM call helpers — every plan-module prompt follows the same shape: # build messages → single VLM call → pull a named field. diff --git a/src/lerobot/annotations/steerable_pipeline/writer.py b/src/lerobot/annotations/steerable_pipeline/writer.py index 6710d08bd..e1a544c80 100644 --- a/src/lerobot/annotations/steerable_pipeline/writer.py +++ b/src/lerobot/annotations/steerable_pipeline/writer.py @@ -89,6 +89,27 @@ def _row_event_sort_key(row: dict[str, Any]) -> tuple: ) +def _normalize_row(row: dict[str, Any], style: str | None, *, with_timestamp: bool) -> dict[str, Any]: + """Coerce a staged row into the language-column struct shape. + + Key order matches ``PERSISTENT_ROW_FIELDS`` / ``EVENT_ROW_FIELDS`` — the + writer infers the parquet struct schema from insertion order, so + ``timestamp`` (persistent rows only) sits between ``style`` and ``camera``. + """ + camera = row.get("camera") + validate_camera_field(style, camera) + out: dict[str, Any] = { + "role": str(row["role"]), + "content": None if row.get("content") is None else str(row["content"]), + "style": style, + } + if with_timestamp: + out["timestamp"] = float(row["timestamp"]) + out["camera"] = None if camera is None else str(camera) + out["tool_calls"] = _normalize_tool_calls(row.get("tool_calls")) + return out + + def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]: """Coerce a staged row into the persistent column's struct shape.""" style = row.get("style") @@ -100,22 +121,10 @@ def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]: if "timestamp" not in row: raise ValueError(f"persistent row missing timestamp: {row!r}") if "role" not in row: - # Surface a friendly error from the writer rather than letting - # the raw KeyError bubble out of the dict access below — modules - # are expected to always emit ``role``, but the validator - # currently doesn't check this so a future bug would otherwise - # be hard to triage. + # Friendly error from the writer instead of a raw KeyError below; + # the validator doesn't check ``role`` yet. raise ValueError(f"persistent row missing role: {row!r}") - camera = row.get("camera") - validate_camera_field(style, camera) - return { - "role": str(row["role"]), - "content": None if row.get("content") is None else str(row["content"]), - "style": style, - "timestamp": float(row["timestamp"]), - "camera": None if camera is None else str(camera), - "tool_calls": _normalize_tool_calls(row.get("tool_calls")), - } + return _normalize_row(row, style, with_timestamp=True) def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]: @@ -129,15 +138,7 @@ def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]: raise ValueError(f"event row with style {style!r} would not route to language_events") if "role" not in row: raise ValueError(f"event row missing role: {row!r}") - camera = row.get("camera") - validate_camera_field(style, camera) - return { - "role": str(row["role"]), - "content": None if row.get("content") is None else str(row["content"]), - "style": style, - "camera": None if camera is None else str(camera), - "tool_calls": _normalize_tool_calls(row.get("tool_calls")), - } + return _normalize_row(row, style, with_timestamp=False) def _normalize_tool_calls(value: Any) -> list[Any] | None: