mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
Compare commits
103 Commits
chore/add-
...
feat/langu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1e7c0d6aa1 | ||
|
|
920c6ef5a2 | ||
|
|
c37b1fc7d0 | ||
|
|
9020635b14 | ||
|
|
471b2b1b1d | ||
|
|
a15e16c072 | ||
|
|
336af85c09 | ||
|
|
54221ceea2 | ||
|
|
369ab17110 | ||
|
|
86a7edc590 | ||
|
|
a0233f53f4 | ||
|
|
2ea0da2d9f | ||
|
|
134a707c7a | ||
|
|
ce47075d6b | ||
|
|
26013da699 | ||
|
|
f72b28738a | ||
|
|
1bd53cc7da | ||
|
|
7128bb1769 | ||
|
|
31e0c15e55 | ||
|
|
c5676ef1b3 | ||
|
|
9dfc9084e1 | ||
|
|
fd18beb3a1 | ||
|
|
965d42825f | ||
|
|
1238a0cd47 | ||
|
|
53c7641885 | ||
|
|
088c8371df | ||
|
|
3a52a18b0e | ||
|
|
dad2cf1178 | ||
|
|
bce5387e04 | ||
|
|
85576acc29 | ||
|
|
e7e5fca5de | ||
|
|
beb22afd81 | ||
|
|
d55b581ca1 | ||
|
|
24d2ffe3c6 | ||
|
|
789f29aa56 | ||
|
|
a356b12c41 | ||
|
|
e8327b8e62 | ||
|
|
c450298147 | ||
|
|
5c30b14929 | ||
|
|
8fa8323c91 | ||
|
|
73740ecf4b | ||
|
|
1b81e49214 | ||
|
|
d813c75b76 | ||
|
|
3434d2ef22 | ||
|
|
b71e10da6b | ||
|
|
0f6e3230df | ||
|
|
2f2e42c4aa | ||
|
|
5ee0104739 | ||
|
|
e064cfcb04 | ||
|
|
b3d9494831 | ||
|
|
1217fdb6f0 | ||
|
|
d0388e1142 | ||
|
|
524aa59faa | ||
|
|
27f7829b09 | ||
|
|
7f8bf108e8 | ||
|
|
855ff027f8 | ||
|
|
3b797bb118 | ||
|
|
aea04721ae | ||
|
|
ab5479129a | ||
|
|
e6d4ac6f02 | ||
|
|
5722d365c5 | ||
|
|
3d7e60cee4 | ||
|
|
7b767d4d60 | ||
|
|
f1e3ab7794 | ||
|
|
585341ba9f | ||
|
|
23ff346027 | ||
|
|
3c5cbe7af4 | ||
|
|
f2cbd97635 | ||
|
|
c06c8d594a | ||
|
|
cd495a3a9d | ||
|
|
c99ac45cd1 | ||
|
|
13aaafeae0 | ||
|
|
2129648bf4 | ||
|
|
f5cd3f6e4e | ||
|
|
ecf5766301 | ||
|
|
11597d4f71 | ||
|
|
8b9c598cf4 | ||
|
|
b325475b38 | ||
|
|
ef137ff86a | ||
|
|
c5df821a96 | ||
|
|
7ec3d7999c | ||
|
|
712d63abbd | ||
|
|
6653999983 | ||
|
|
4bdbedc9a0 | ||
|
|
e240305e8e | ||
|
|
ccd189b264 | ||
|
|
ef1242bbd4 | ||
|
|
ebf4a04d41 | ||
|
|
4419b4ef1b | ||
|
|
ff06ca82d2 | ||
|
|
fcb01e73eb | ||
|
|
268f8d1f53 | ||
|
|
663fff0ae2 | ||
|
|
9d6af804bf | ||
|
|
f763f85213 | ||
|
|
e3e9374e2c | ||
|
|
c1a0c601e2 | ||
|
|
1ca38d9748 | ||
|
|
5a6aa64570 | ||
|
|
0b06790da0 | ||
|
|
b43dc39ba4 | ||
|
|
2b71221194 | ||
|
|
8833d735a1 |
11
.github/dependabot.yml
vendored
11
.github/dependabot.yml
vendored
@@ -1,11 +0,0 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
cooldown:
|
||||
default-days: 7
|
||||
groups:
|
||||
actions:
|
||||
patterns: ["*"]
|
||||
6
Makefile
6
Makefile
@@ -178,3 +178,9 @@ test-smolvla-ete-eval:
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
# E2E annotation pipeline smoke test against a tiny in-memory fixture
|
||||
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
|
||||
# backend, so it does not require a real model checkpoint or GPU.
|
||||
annotation-e2e:
|
||||
uv run python -m tests.annotations.run_e2e_smoke
|
||||
|
||||
@@ -43,6 +43,8 @@
|
||||
title: Language Columns and Recipes
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: annotation_pipeline
|
||||
title: Annotation Pipeline
|
||||
- local: video_encoding_parameters
|
||||
title: Video encoding parameters
|
||||
- local: streaming_video_encoding
|
||||
|
||||
198
docs/source/annotation_pipeline.mdx
Normal file
198
docs/source/annotation_pipeline.mdx
Normal file
@@ -0,0 +1,198 @@
|
||||
# Annotation Pipeline
|
||||
|
||||
`lerobot-annotate` populates the two language columns introduced by the
|
||||
[Language Columns and Recipes](./language_and_recipes) page —
|
||||
`language_persistent` and `language_events` — directly into
|
||||
`data/chunk-*/file-*.parquet`.
|
||||
|
||||
## What the pipeline produces
|
||||
|
||||
A vocabulary-discovery phase derives a small canonical wording, then three
|
||||
modules write into a per-episode staging tree, then a single writer
|
||||
rewrites the data shards in place:
|
||||
|
||||
| Style / atom | Column | Module |
|
||||
| ------------------------------------------- | --------------------- | -------------- |
|
||||
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
|
||||
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
|
||||
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
|
||||
| `task_aug` (rephrasings of canonical task) | `language_persistent` | `plan` |
|
||||
| `interjection` | `language_events` | `interjections`|
|
||||
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections`|
|
||||
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
|
||||
|
||||
The `plan` module is constrained to a **canonical vocabulary** discovered
|
||||
once per dataset by the `vocabulary` module (phase 0). It watches a few
|
||||
sample episode videos (`--vocabulary.sample_episodes`, default `3`) and
|
||||
asks the VLM to derive a small set of imperative subtask labels and
|
||||
first-person memory milestones that recur across the demos. The VLM
|
||||
picks the right number of entries itself based on what it sees in the
|
||||
clips — short pick-and-place demos get ~6 subtask labels, longer
|
||||
multi-step recipes get more. The result lands at
|
||||
`meta/canonical_vocabulary.json` (human-readable / hand-editable) and
|
||||
is reused on every subsequent run. The `plan` module then constrains
|
||||
both subtask + memory generation to those exact strings — the
|
||||
downstream low-level policy sees a small, repeatable target
|
||||
distribution instead of thousands of LLM paraphrases. Disable with
|
||||
`--vocabulary.enabled=False` to fall back to free-form generation.
|
||||
|
||||
The writer does **not** add a `tools` column to the parquet — the tool
|
||||
catalog lives at `meta/info.json["tools"]` instead (see
|
||||
[Tools](./tools)). After every annotation run the pipeline ensures the
|
||||
canonical `say` schema is present in that list, preserving any tools the
|
||||
user pre-declared.
|
||||
|
||||
If you want to declare additional tools for a dataset before annotation
|
||||
runs, edit `meta/info.json["tools"]` directly — the pipeline preserves
|
||||
anything already there. Implementations of those tools live under
|
||||
`src/lerobot/tools/`; one file per tool, registered via
|
||||
`TOOL_REGISTRY`. See the [Tools](./tools) doc for the authoring guide.
|
||||
|
||||
## Running locally
|
||||
|
||||
Install the extra and invoke the console script. Episode-level
|
||||
concurrency comes from `--executor.episode_parallelism` (default 16);
|
||||
that is the only knob the in-process executor exposes.
|
||||
|
||||
```bash
|
||||
uv sync --extra annotations
|
||||
uv run lerobot-annotate \
|
||||
--root=/path/to/dataset \
|
||||
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
||||
```
|
||||
|
||||
The pipeline attaches actual camera footage to every `plan` /
|
||||
`interjections` / `vqa` prompt by default, decoded from the dataset's
|
||||
first `observation.images.*` stream. Override with
|
||||
`--vlm.camera_key=observation.images.<name>` to pin a specific
|
||||
viewpoint. Datasets with no video tracks fall back to text-only prompts
|
||||
automatically.
|
||||
|
||||
**The `plan` module sees the whole episode as one video block.** Subtask
|
||||
decomposition gets a `{"type":"video", "video":[<frames>]}` block
|
||||
covering the entire demonstration; Qwen-VL pools temporally on its own
|
||||
and decides where to cut. There is no keyframe stride or count knob —
|
||||
`--plan.max_video_frames` (default 128) only caps the frames packed
|
||||
into the video block as a model-capacity bound. The `interjections`
|
||||
module attaches a short window of frames straddling the interjection
|
||||
timestamp. The `vqa` module grounds each VQA pair on a single frame —
|
||||
its `--vqa.K` knob sets how many consecutive frames each emission tick
|
||||
anchors, and every anchored frame gets its own VQA pair on that one
|
||||
frame (there is no per-pair frame window).
|
||||
|
||||
## Running on Hugging Face Jobs
|
||||
|
||||
Distributed annotation is delegated to
|
||||
[Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs). The repo
|
||||
ships a launcher script you copy and edit for your dataset:
|
||||
|
||||
```bash
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
```
|
||||
|
||||
[`examples/annotations/run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
spawns one `h200x2` job that:
|
||||
|
||||
1. installs the branch under test plus the annotation extras,
|
||||
2. boots two vllm servers (one per GPU) for the chosen model,
|
||||
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
|
||||
via `lerobot-annotate`,
|
||||
4. uploads the annotated dataset to `--push_to_hub`.
|
||||
|
||||
To target a different dataset, model, or hub repo, edit the `CMD` block
|
||||
inside the script — every flag in there maps directly onto a CLI flag of
|
||||
`lerobot-annotate` (see `lerobot-annotate --help` for the full list).
|
||||
|
||||
## Style-to-recipe consumer mapping
|
||||
|
||||
The pipeline's outputs are designed to be consumed by recipes (see
|
||||
[Language Columns and Recipes](./language_and_recipes)) — typically:
|
||||
|
||||
- low-level / high-level / memory-update branches consume
|
||||
`subtask`/`plan`/`memory` from `language_persistent`.
|
||||
- An interjection-response branch consumes `interjection` events plus
|
||||
the paired speech atom (merged into one assistant target turn via
|
||||
`tool_calls_from`) and the same-timestamp `plan` refresh.
|
||||
- A VQA branch consumes the `(vqa, user)` and `(vqa, assistant)` pairs
|
||||
from `language_events`.
|
||||
|
||||
## Why the design splits state from events
|
||||
|
||||
Two things drive the scope:
|
||||
|
||||
1. **Persistent state vs exact-event split.** Persistent rows
|
||||
(`subtask`, `plan`, `memory`) broadcast per episode and answer "what
|
||||
state is in force at this frame?". Event rows (`interjection`, `vqa`,
|
||||
speech) only appear on the exact frame whose timestamp matches the
|
||||
emission. The pipeline writes timestamps taken straight from the
|
||||
source parquet — no floating-point recomputation.
|
||||
2. **One Qwen-VL pass.** All three modules share a single VLM client
|
||||
(vLLM if available, transformers fallback) so the cost is one model
|
||||
load per dataset, not three.
|
||||
|
||||
## Module independence and staged reruns
|
||||
|
||||
Each module writes its raw output to
|
||||
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. That makes
|
||||
prompt iteration cheap — re-running one module overwrites only its own
|
||||
JSONL file before the writer composes the final parquet. Modules can be
|
||||
disabled via `--plan.enabled=false` (and likewise `--interjections.enabled`
|
||||
/ `--vqa.enabled`) to
|
||||
test them in isolation.
|
||||
|
||||
## Validation/report checks before final write
|
||||
|
||||
Before the writer runs, `StagingValidator` checks:
|
||||
|
||||
- exact frame-timestamp alignment for every event row;
|
||||
- no orphan speech / interjection pairs;
|
||||
- `plan` is refreshed at every interjection timestamp;
|
||||
- `memory` rows fall on subtask boundaries (warning, not error);
|
||||
- VQA assistant `content` parses as JSON in one of the
|
||||
bbox / keypoint / count / attribute / spatial shapes;
|
||||
- every row routes to the column dictated by `column_for_style(style)`.
|
||||
|
||||
Errors abort the writer (`--skip_validation=true` overrides for debugging).
|
||||
|
||||
## Paper inspirations per module
|
||||
|
||||
- **`plan` module — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
|
||||
atom granularity ("pick up one piece of lettuce", "place bowl to box");
|
||||
Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07)) "how, not
|
||||
what" detail.
|
||||
- **`plan` module — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596))
|
||||
compression directive: keep only minimal relevant information; functional
|
||||
outcomes preserved, specific attributes dropped.
|
||||
- **`interjections` module.** Hi Robot scenario taxonomy: negative task,
|
||||
situated correction, specific constraint, preference. Speech is a
|
||||
tool-call-only atom (`tool_calls=[{type:function, function:{name:"say",
|
||||
arguments:{text:...}}}]`).
|
||||
- **`vqa` module.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693))
|
||||
grounded features (bounding boxes in pixel `[x_min, y_min, x_max, y_max]`,
|
||||
keypoints) and Steerable VLA Policies ([Zhao 2025](https://arxiv.org/abs/2509.07626))
|
||||
multi-abstraction grounding. Pi0.7 also grounds answers across
|
||||
multiple abstraction levels.
|
||||
|
||||
Future maintainers should adjust the prompt templates in
|
||||
`src/lerobot/annotations/steerable_pipeline/prompts/` against these
|
||||
references rather than rewriting from scratch.
|
||||
|
||||
## Compute and list-size estimates
|
||||
|
||||
Per episode, the pipeline issues O(`max_steps`) `plan`-module calls,
|
||||
O(`max_interjections_per_episode`) `interjections`-module calls, and
|
||||
O(`vqa_emission_hz × episode_seconds`) `vqa`-module calls. With defaults
|
||||
(8 subtasks, 1 interjection, 1 Hz × 3 pairs) and 30-second episodes, that
|
||||
is ~50 VLM calls per episode. `language_persistent` per episode is ~10s of
|
||||
KB at most (parquet dictionary-encodes one entry per episode);
|
||||
`language_events` is empty on most frames and is bounded by the number of
|
||||
emissions, not `num_frames × num_emissions`.
|
||||
|
||||
## Reproducibility via seed and prompt hashes
|
||||
|
||||
`--seed` (default 1729) feeds the per-episode RNGs that select interjection
|
||||
timestamps and VQA question types. Combined with the deterministic prompt
|
||||
templates checked into `prompts/`, two runs at the same seed against the
|
||||
same dataset and the same model checkpoint produce byte-identical staging
|
||||
artifacts. Prompt edits are recorded by file hash; future tooling can pin
|
||||
expected `(seed, prompt_hash)` pairs into the dataset card.
|
||||
89
examples/annotations/run_hf_job.py
Normal file
89
examples/annotations/run_hf_job.py
Normal file
@@ -0,0 +1,89 @@
|
||||
#!/usr/bin/env python
|
||||
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6 MoE).
|
||||
|
||||
Spawns one ``h200x2`` job that:
|
||||
|
||||
1. installs this branch of ``lerobot`` plus the annotation extras,
|
||||
2. boots two vllm servers (one per GPU) with Qwen3.6-35B-A3B-FP8,
|
||||
3. runs the plan / interjections / vqa modules across the dataset
|
||||
in free-form mode (phase 0 canonical-vocabulary discovery is
|
||||
disabled — each episode generates its own subtasks + memory),
|
||||
4. uploads the annotated dataset to ``--dest_repo_id`` (when set)
|
||||
or back to ``--repo_id``.
|
||||
|
||||
Re-enable phase 0 with ``--vocabulary.enabled=true`` (optionally
|
||||
``--vocabulary.sample_episodes=N``) when the dataset is homogeneous
|
||||
enough to share one subtask + memory vocabulary across all episodes.
|
||||
|
||||
Usage:
|
||||
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
|
||||
Adjust ``CMD`` below to point at your own dataset / target hub repo.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from huggingface_hub import get_token, run_job
|
||||
|
||||
token = os.environ.get("HF_TOKEN") or get_token()
|
||||
if not token:
|
||||
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
|
||||
|
||||
CMD = (
|
||||
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
|
||||
"pip install --no-deps "
|
||||
"'lerobot @ git+https://github.com/huggingface/lerobot.git@feat/language-annotation-pipeline' && "
|
||||
"pip install --upgrade-strategy only-if-needed "
|
||||
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect "
|
||||
"openai && "
|
||||
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
|
||||
"export VLLM_VIDEO_BACKEND=pyav && "
|
||||
"lerobot-annotate "
|
||||
"--repo_id=imstevenpmwork/super_poulain_draft "
|
||||
"--dest_repo_id=pepijn223/super_poulain_vocab "
|
||||
"--push_to_hub=true "
|
||||
"--vlm.backend=openai "
|
||||
"--vlm.model_id=Qwen/Qwen3.6-35B-A3B-FP8 "
|
||||
"--vlm.parallel_servers=2 "
|
||||
"--vlm.num_gpus=2 "
|
||||
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-35B-A3B-FP8 '
|
||||
"--tensor-parallel-size 1 --max-model-len 32768 "
|
||||
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
|
||||
"--vlm.serve_ready_timeout_s=1800 "
|
||||
"--vlm.client_concurrency=128 "
|
||||
"--vlm.max_new_tokens=512 "
|
||||
"--vlm.temperature=0.7 "
|
||||
"--executor.episode_parallelism=16 "
|
||||
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
|
||||
"--vlm.camera_key=observation.images.wrist "
|
||||
# Phase 0 — canonical vocabulary discovery DISABLED by default.
|
||||
# Heterogeneous datasets (different tasks/scenes across episodes)
|
||||
# don't share a single small subtask + memory vocabulary, so each
|
||||
# episode generates its subtasks + memory free-form. Flip to
|
||||
# ``--vocabulary.enabled=true`` (optionally ``--vocabulary.sample_episodes=N``)
|
||||
# for homogeneous datasets where a shared canonical vocabulary
|
||||
# helps the downstream policy.
|
||||
"--vocabulary.enabled=false "
|
||||
# Phase 1 — plan module (subtasks + plan + memory + task_aug).
|
||||
"--plan.frames_per_second=1.0 "
|
||||
"--plan.use_video_url=true "
|
||||
"--plan.use_video_url_fps=1.0 "
|
||||
"--plan.derive_task_from_video=always "
|
||||
"--plan.n_task_rephrasings=30 "
|
||||
# Phase 2 — interjections + speech.
|
||||
"--interjections.max_interjections_per_episode=6 "
|
||||
# Phase 4 — general VQA.
|
||||
"--vqa.K=3 "
|
||||
"--vqa.vqa_emission_hz=1.0"
|
||||
)
|
||||
|
||||
job = run_job(
|
||||
image="vllm/vllm-openai:latest",
|
||||
command=["bash", "-c", CMD],
|
||||
flavor="h200x2",
|
||||
secrets={"HF_TOKEN": token},
|
||||
timeout="2h",
|
||||
)
|
||||
print(f"Job URL: {job.url}")
|
||||
print(f"Job ID: {job.id}")
|
||||
@@ -219,6 +219,18 @@ hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Annotation pipeline (lerobot-annotate). vllm is the preferred backend
|
||||
# on Linux, with a transformers fallback elsewhere; openai is the default
|
||||
# backend and talks to any OpenAI-compatible server (``vllm serve`` /
|
||||
# ``transformers serve`` / hosted endpoints). Distributed execution is
|
||||
# delegated to Hugging Face Jobs (see examples/annotations/run_hf_job.py).
|
||||
annotations = [
|
||||
"lerobot[dataset]",
|
||||
"lerobot[transformers-dep]",
|
||||
"openai>=1.40,<2.0",
|
||||
"vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
@@ -309,6 +321,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
@@ -327,7 +340,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
15
src/lerobot/annotations/__init__.py
Normal file
15
src/lerobot/annotations/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
50
src/lerobot/annotations/steerable_pipeline/__init__.py
Normal file
50
src/lerobot/annotations/steerable_pipeline/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Steerable annotation pipeline producing ``language_persistent`` and
|
||||
``language_events`` columns for LeRobot datasets.
|
||||
|
||||
The pipeline is decomposed into three independently runnable modules whose
|
||||
outputs are staged per-episode before a final parquet rewrite:
|
||||
|
||||
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
|
||||
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
|
||||
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
|
||||
"""
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .validator import StagingValidator, ValidationReport
|
||||
from .vocabulary import (
|
||||
VOCABULARY_FILENAME,
|
||||
Vocabulary,
|
||||
VocabularyDiscoveryModule,
|
||||
load_vocabulary,
|
||||
save_vocabulary,
|
||||
vocabulary_path,
|
||||
)
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
__all__ = [
|
||||
"VOCABULARY_FILENAME",
|
||||
"AnnotationPipelineConfig",
|
||||
"LanguageColumnsWriter",
|
||||
"StagingValidator",
|
||||
"ValidationReport",
|
||||
"Vocabulary",
|
||||
"VocabularyDiscoveryModule",
|
||||
"load_vocabulary",
|
||||
"save_vocabulary",
|
||||
"vocabulary_path",
|
||||
]
|
||||
251
src/lerobot/annotations/steerable_pipeline/config.py
Normal file
251
src/lerobot/annotations/steerable_pipeline/config.py
Normal file
@@ -0,0 +1,251 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocabularyConfig:
|
||||
"""Phase 0 — dataset-level canonical vocabulary discovery.
|
||||
|
||||
Watches the first ``sample_episodes`` episode videos and asks the VLM
|
||||
to derive a small canonical vocabulary (subtask labels + memory
|
||||
milestones) that every episode in the dataset will reuse. The VLM
|
||||
decides the count itself from what it sees in the clips — short
|
||||
pick-and-place demos get ~6 labels, longer multi-step recipes more.
|
||||
The output lands at ``meta/canonical_vocabulary.json`` and feeds
|
||||
phase 1's subtask + memory generation as both a prompt-side
|
||||
constraint and a post-VLM validation gate.
|
||||
|
||||
Why this exists: free-form LLM rephrasing per episode produces near-
|
||||
unique subtask strings, which makes the downstream low-level policy's
|
||||
conditioning effectively noise — at inference the policy generates a
|
||||
*new* paraphrase the action expert has never seen and produces tiny
|
||||
cautious actions. Forcing every episode onto the same small set of
|
||||
canonical strings gives the action expert dense supervision per
|
||||
string and a small target distribution to learn against.
|
||||
|
||||
Set ``enabled=False`` to fall back to free-form generation (original
|
||||
behaviour). ``reuse_existing=True`` keeps a hand-edited vocabulary
|
||||
file from being clobbered on re-runs.
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
sample_episodes: int = 3
|
||||
max_video_frames_per_episode: int = 32
|
||||
# When True (default), an existing meta/canonical_vocabulary.json is
|
||||
# loaded as-is and no VLM call is made — lets operators hand-edit the
|
||||
# file. Set False to always rediscover from the sample episodes.
|
||||
reuse_existing: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanConfig:
|
||||
"""``plan`` module: plan + subtasks + memory + task augmentation.
|
||||
|
||||
The ``plan`` module attaches the whole episode as one Qwen-VL video
|
||||
block; ``max_video_frames`` only caps the frames packed in (a
|
||||
model-capacity bound, not an annotation-logic knob).
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Number of ``task_aug`` rephrasings emitted at ``t=0``. The renderer's
|
||||
# ``${task}`` binding rotates among them per ``sample_idx``. ``0`` disables.
|
||||
n_task_rephrasings: int = 10
|
||||
|
||||
# When to derive the task from the video instead of using
|
||||
# ``record.episode_task``: ``off``, ``if_short`` (short / placeholder /
|
||||
# missing canonical task), or ``always``. The derived task replaces the
|
||||
# canonical one for every ``plan``-module prompt; ``meta/tasks.parquet``
|
||||
# is never modified.
|
||||
derive_task_from_video: str = "if_short"
|
||||
derive_task_min_words: int = 3
|
||||
|
||||
# Frame sampling for the subtask-decomposition prompt.
|
||||
frames_per_second: float = 1.0
|
||||
max_video_frames: int = 128
|
||||
|
||||
min_subtask_seconds: float = 1.5
|
||||
plan_max_steps: int = 8
|
||||
|
||||
# When True (and backend supports it, e.g. ``openai``), the ``plan``
|
||||
# module sends a ``video_url`` block pointing at a per-episode mp4
|
||||
# subclip and lets the server sample frames at ``use_video_url_fps``.
|
||||
use_video_url: bool = False
|
||||
use_video_url_fps: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsConfig:
|
||||
"""``interjections`` module: interjections + paired speech."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Each interjection emits a paired ``(interjection, speech)`` event row
|
||||
# and triggers a ``plan`` refresh at the same timestamp via the
|
||||
# ``plan`` module.
|
||||
max_interjections_per_episode: int = 3
|
||||
interjection_min_t: float = 2.0
|
||||
|
||||
# Visual context attached to the interjection prompt: a short window
|
||||
# of frames centered on the chosen timestamp so the VLM sees the
|
||||
# ongoing motion rather than a single frozen frame.
|
||||
interjection_window_seconds: float = 2.0
|
||||
interjection_window_frames: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class VqaConfig:
|
||||
"""``vqa`` module: general VQA."""
|
||||
|
||||
enabled: bool = True
|
||||
vqa_emission_hz: float = 1.0
|
||||
K: int = 1
|
||||
"""How many *consecutive* frames each emission tick anchors a VQA pair
|
||||
to. The VLM grounds its answer (bbox / keypoint coordinates, count, …)
|
||||
against the *first* anchored frame's image, so anchoring K>1 frames
|
||||
copies that same answer onto later frames where the scene has already
|
||||
moved — stale labels. Default ``1``: a VQA pair lands on exactly its
|
||||
emission frame, no temporal smear. Raise it only to trade label
|
||||
precision for more (noisier) VQA frames."""
|
||||
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
|
||||
@dataclass
|
||||
class VlmConfig:
|
||||
"""Shared Qwen-VL client configuration."""
|
||||
|
||||
# One of ``vllm``, ``transformers``, ``openai``, or ``stub`` (tests).
|
||||
# ``openai`` talks to a local OpenAI-compatible server; the CLI
|
||||
# auto-spawns one when ``auto_serve=True``.
|
||||
backend: str = "openai"
|
||||
model_id: str = "Qwen/Qwen3.6-35B-A3B-FP8"
|
||||
|
||||
# OpenAI-compatible server endpoint; ``EMPTY`` works for local servers.
|
||||
api_base: str = "http://localhost:8000/v1"
|
||||
api_key: str = "EMPTY"
|
||||
|
||||
# When True with ``backend=openai``, the CLI probes ``api_base`` and
|
||||
# spawns a server if none answers (default: ``transformers serve``).
|
||||
# Set to False to fail fast when pointing at a remote endpoint.
|
||||
auto_serve: bool = True
|
||||
serve_port: int = 8000
|
||||
# Override the auto-serve command. ``{port}`` is substituted per replica
|
||||
# when ``parallel_servers > 1``.
|
||||
serve_command: str | None = None
|
||||
|
||||
# Run multiple independent inference servers for round-robin client
|
||||
# routing (each pinned to a GPU via ``CUDA_VISIBLE_DEVICES`` and bound
|
||||
# to ``serve_port + i``). ``num_gpus=0`` means one GPU per replica.
|
||||
parallel_servers: int = 1
|
||||
num_gpus: int = 0
|
||||
client_concurrency: int = 16
|
||||
serve_ready_timeout_s: float = 600.0
|
||||
|
||||
max_new_tokens: int = 512
|
||||
temperature: float = 0.2
|
||||
json_mode: bool = True
|
||||
batch_size: int = 4
|
||||
tensor_parallel_size: int = 1
|
||||
|
||||
# Fraction of GPU memory vllm allocates for weights + KV cache.
|
||||
gpu_memory_utilization: float = 0.9
|
||||
# Cap context length (None = model default). On 80 GB H100 a 30B BF16
|
||||
# model often needs <= 8192 to leave KV-cache headroom.
|
||||
max_model_len: int | None = None
|
||||
trust_remote_code: bool = False
|
||||
|
||||
# Override the camera stream used for keyframe attachment. None picks
|
||||
# the first ``observation.images.*`` key the dataset declares.
|
||||
camera_key: str | None = None
|
||||
# Forwarded as ``extra_body.chat_template_kwargs`` on every chat call;
|
||||
# use to pass model-specific flags such as ``{"enable_thinking": false}``.
|
||||
chat_template_kwargs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutorConfig:
|
||||
"""Executor settings.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotation/run_hf_job.py``); this config only controls
|
||||
intra-process episode concurrency.
|
||||
"""
|
||||
|
||||
# Episodes processed concurrently within each module phase. Each
|
||||
# in-flight episode dispatches 3-5 dependent VLM calls, so this is the
|
||||
# main knob for saturating ``parallel_servers`` and ``client_concurrency``.
|
||||
episode_parallelism: int = 16
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnnotationPipelineConfig:
|
||||
"""Top-level config for ``lerobot-annotate``.
|
||||
|
||||
The writer rewrites ``data/chunk-*/file-*.parquet`` in place. Multiple
|
||||
revisions of the same dataset live in separate copies.
|
||||
"""
|
||||
|
||||
# Hub dataset id. Used as the download source when ``root`` is unset,
|
||||
# and as the destination repo when ``push_to_hub`` is enabled and
|
||||
# ``dest_repo_id`` is unset.
|
||||
repo_id: str | None = None
|
||||
|
||||
# Optional separate Hub dataset id to push the annotated result to. When
|
||||
# unset, ``push_to_hub`` uploads back to ``repo_id`` (annotate in place);
|
||||
# when set, the source ``repo_id`` is left untouched.
|
||||
dest_repo_id: str | None = None
|
||||
|
||||
root: Path | None = None
|
||||
|
||||
# Defaults to ``<root>/.annotate_staging/`` when unset.
|
||||
staging_dir: Path | None = None
|
||||
|
||||
seed: int = 1729
|
||||
|
||||
vocabulary: VocabularyConfig = field(default_factory=VocabularyConfig)
|
||||
plan: PlanConfig = field(default_factory=PlanConfig)
|
||||
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
|
||||
vqa: VqaConfig = field(default_factory=VqaConfig)
|
||||
|
||||
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||
|
||||
skip_validation: bool = False
|
||||
only_episodes: tuple[int, ...] | None = None
|
||||
|
||||
# Keyframe decode backend. When unset, the pipeline decodes with the
|
||||
# ffmpeg CLI: it decodes AV1 and runs each decode as an isolated child
|
||||
# process, which is both crash-safe and safe under the concurrent
|
||||
# decode the executor performs (torchcodec is not thread-safe and
|
||||
# SIGSEGVs there). Set to ``"torchcodec"`` or ``"pyav"`` to pin an
|
||||
# in-process decoder when its build is known thread-safe.
|
||||
video_backend: str | None = None
|
||||
|
||||
# When True, upload the annotated dataset to the Hugging Face Hub:
|
||||
# to ``dest_repo_id`` if set, otherwise back to ``repo_id``. One of
|
||||
# the two must be set for this to take effect.
|
||||
push_to_hub: bool = False
|
||||
push_private: bool = False
|
||||
push_commit_message: str | None = None
|
||||
|
||||
def resolved_staging_dir(self, root: Path) -> Path:
|
||||
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
|
||||
322
src/lerobot/annotations/steerable_pipeline/executor.py
Normal file
322
src/lerobot/annotations/steerable_pipeline/executor.py
Normal file
@@ -0,0 +1,322 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""In-process executor that runs the annotation phases.
|
||||
|
||||
The executor plans **seven phases** in the dependency order from the plan:
|
||||
|
||||
phase 0: vocabulary discovery — derive a small canonical vocabulary
|
||||
from the first few sample-episode videos (subtask labels +
|
||||
memory milestones) and persist it next to the dataset; the
|
||||
``plan`` module then constrains every per-episode generation
|
||||
to those strings, so the downstream policy sees a small,
|
||||
repeatable conditioning distribution
|
||||
phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phase 2: ``interjections`` module (interjections + speech)
|
||||
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
|
||||
interjection timestamp produced by phase 2
|
||||
phase 4: ``vqa`` module (VQA)
|
||||
phase 5: validator
|
||||
phase 6: writer
|
||||
|
||||
Phase 3 is why the ``plan`` module must be re-entered after the
|
||||
``interjections`` module — to refresh ``plan`` rows at interjection
|
||||
timestamps.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotations/run_hf_job.py``); the runner inside the job
|
||||
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||
Episode-level concurrency is controlled by
|
||||
``ExecutorConfig.episode_parallelism``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .reader import EpisodeRecord, iter_episodes
|
||||
from .staging import EpisodeStaging
|
||||
from .validator import StagingValidator
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhaseResult:
|
||||
"""Summary of one pipeline phase across all episodes."""
|
||||
|
||||
name: str
|
||||
episodes_processed: int
|
||||
episodes_skipped: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRunSummary:
|
||||
"""Aggregated result returned by :meth:`Executor.run`."""
|
||||
|
||||
phases: list[PhaseResult]
|
||||
written_paths: list[Path]
|
||||
validation_report: Any # ValidationReport, kept Any to avoid import cycle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Executor:
|
||||
"""Run all six phases over a dataset root in-process.
|
||||
|
||||
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
|
||||
(a thread pool); cluster-level concurrency comes from running this
|
||||
executor inside a Hugging Face Job. Tests construct the executor
|
||||
directly with stub modules.
|
||||
"""
|
||||
|
||||
config: AnnotationPipelineConfig
|
||||
plan: Any # PlanSubtasksMemoryModule
|
||||
interjections: Any # InterjectionsAndSpeechModule
|
||||
vqa: Any # GeneralVqaModule
|
||||
writer: LanguageColumnsWriter
|
||||
validator: StagingValidator
|
||||
vocabulary: Any = None # VocabularyDiscoveryModule | None
|
||||
|
||||
def run(self, root: Path) -> PipelineRunSummary:
|
||||
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
|
||||
n = len(records)
|
||||
if n == 0:
|
||||
raise ValueError(f"No episodes found under {root}/data/")
|
||||
|
||||
print(f"[annotate] {n} episodes total", flush=True)
|
||||
|
||||
staging_dir = self.config.resolved_staging_dir(root)
|
||||
staging_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
phases: list[PhaseResult] = []
|
||||
|
||||
# Phase 0: vocabulary discovery. Mutates ``self.plan.vocabulary``
|
||||
# so subsequent per-episode plan calls see the canonical labels.
|
||||
phases.append(self._run_vocabulary_phase(records, root))
|
||||
|
||||
# Phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
|
||||
# Phase 2: ``interjections`` module (interjections + speech). It
|
||||
# reads the ``plan`` module's subtask rows from the same staging
|
||||
# tree to ground the interjection prompt in the correct local subtask.
|
||||
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
|
||||
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
|
||||
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||
# Phase 4: ``vqa`` module (VQA)
|
||||
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
if not report.ok and not self.config.skip_validation:
|
||||
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
||||
print(f"[annotate] validator: {report.summary()}", flush=True)
|
||||
|
||||
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
|
||||
written = self.writer.write_all(records, staging_dir, root)
|
||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||
|
||||
# Keep meta/info.json aligned with the parquet schema we just wrote.
|
||||
# Idempotent and additive: existing user metadata is preserved.
|
||||
self._ensure_annotation_metadata_in_info(root)
|
||||
|
||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_annotation_metadata_in_info(root: Path) -> None:
|
||||
"""Write language features and canonical tools to ``meta/info.json``.
|
||||
|
||||
``LanguageColumnsWriter`` adds ``language_persistent`` and
|
||||
``language_events`` to parquet shards. The metadata must advertise
|
||||
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
|
||||
cast against the old schema and fail on the extra parquet columns.
|
||||
"""
|
||||
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
if not info_path.exists():
|
||||
return
|
||||
try:
|
||||
info = load_info(root)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
||||
return
|
||||
|
||||
changed = False
|
||||
|
||||
merged_features = {**info.features, **language_feature_info()}
|
||||
if merged_features != info.features:
|
||||
info.features = merged_features
|
||||
changed = True
|
||||
|
||||
existing = info.tools or []
|
||||
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
|
||||
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
|
||||
info.tools = [*existing, SAY_TOOL_SCHEMA]
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
write_info(info, root)
|
||||
print(
|
||||
"[annotate] meta/info.json: "
|
||||
f"language_features={list(language_feature_info())}, "
|
||||
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _run_vocabulary_phase(
|
||||
self, records: list[EpisodeRecord], root: Path
|
||||
) -> PhaseResult:
|
||||
"""Discover (or load) the canonical vocabulary, wire it into ``self.plan``.
|
||||
|
||||
Returns a ``PhaseResult`` whose ``episodes_processed`` is the number
|
||||
of sample episodes consulted (0 when disabled or no VLM call was
|
||||
needed); ``episodes_skipped`` is always ``0`` because vocabulary is
|
||||
a once-per-dataset artifact, not a per-episode product.
|
||||
"""
|
||||
from .vocabulary import load_vocabulary, save_vocabulary # noqa: PLC0415
|
||||
|
||||
if self.vocabulary is None or not getattr(self.vocabulary, "enabled", False):
|
||||
print(
|
||||
"[annotate] phase=vocabulary skipped (module disabled or unset)",
|
||||
flush=True,
|
||||
)
|
||||
return PhaseResult(name="vocabulary", episodes_processed=0, episodes_skipped=0)
|
||||
|
||||
existing = load_vocabulary(root)
|
||||
if existing is not None and self.config.vocabulary.reuse_existing:
|
||||
print(
|
||||
f"[annotate] phase=vocabulary reusing {root / 'meta' / 'canonical_vocabulary.json'} "
|
||||
f"({len(existing.subtasks)} subtask labels, "
|
||||
f"{len(existing.memory_milestones)} memory milestones)",
|
||||
flush=True,
|
||||
)
|
||||
self.plan.vocabulary = existing
|
||||
return PhaseResult(name="vocabulary", episodes_processed=0, episodes_skipped=0)
|
||||
|
||||
sample_n = max(1, min(int(self.config.vocabulary.sample_episodes), len(records)))
|
||||
print(
|
||||
f"[annotate] phase=vocabulary discovering from {sample_n} sample episode(s)...",
|
||||
flush=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
vocab = self.vocabulary.discover(records[:sample_n], existing=existing)
|
||||
if vocab is None:
|
||||
print(
|
||||
"[annotate] phase=vocabulary returned no vocabulary — "
|
||||
"plan module will fall back to free-form generation",
|
||||
flush=True,
|
||||
)
|
||||
return PhaseResult(name="vocabulary", episodes_processed=0, episodes_skipped=0)
|
||||
|
||||
save_path = save_vocabulary(root, vocab)
|
||||
print(
|
||||
f"[annotate] phase=vocabulary wrote {save_path} "
|
||||
f"({len(vocab.subtasks)} subtask labels, "
|
||||
f"{len(vocab.memory_milestones)} memory milestones) in "
|
||||
f"{time.time() - t0:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
self.plan.vocabulary = vocab
|
||||
return PhaseResult(name="vocabulary", episodes_processed=sample_n, episodes_skipped=0)
|
||||
|
||||
def _run_module_phase(
|
||||
self,
|
||||
name: str,
|
||||
records: list[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
module: Any,
|
||||
) -> PhaseResult:
|
||||
if not module.enabled:
|
||||
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||
n = len(records)
|
||||
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
|
||||
print(
|
||||
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
|
||||
flush=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
|
||||
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
|
||||
i, record = idx_record
|
||||
ep_start = time.time()
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
return i, record.episode_index, time.time() - ep_start
|
||||
|
||||
processed = 0
|
||||
if parallelism == 1:
|
||||
for i, record in enumerate(records, 1):
|
||||
_, ep_idx, elapsed = _do((i, record))
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=parallelism) as pool:
|
||||
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
|
||||
for fut in as_completed(futures):
|
||||
i, ep_idx, elapsed = fut.result()
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {processed}/{n} "
|
||||
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
total = time.time() - t0
|
||||
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||
|
||||
def _run_plan_update_phase( # noqa: PLR0915
|
||||
self, records: list[EpisodeRecord], staging_dir: Path
|
||||
) -> PhaseResult:
|
||||
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
|
||||
|
||||
The ``plan`` module owns the prompt; the ``interjections`` module
|
||||
produced the timestamps. This phase therefore calls back into the
|
||||
``plan`` module with the interjection timestamps so its existing
|
||||
prompt path is reused.
|
||||
"""
|
||||
if not self.plan.enabled or not self.interjections.enabled:
|
||||
return PhaseResult(
|
||||
name="plan_update", episodes_processed=0, episodes_skipped=len(records)
|
||||
)
|
||||
processed = 0
|
||||
for record in records:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
interjection_rows = [
|
||||
row for row in staging.read("interjections") if row.get("style") == "interjection"
|
||||
]
|
||||
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
|
||||
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
|
||||
if interjection_times:
|
||||
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
|
||||
processed += 1
|
||||
# Episodes without any interjections are skipped (no plan refresh
|
||||
# needed); count them so the summary's processed+skipped == total.
|
||||
return PhaseResult(
|
||||
name="plan_update",
|
||||
episodes_processed=processed,
|
||||
episodes_skipped=len(records) - processed,
|
||||
)
|
||||
483
src/lerobot/annotations/steerable_pipeline/frames.py
Normal file
483
src/lerobot/annotations/steerable_pipeline/frames.py
Normal file
@@ -0,0 +1,483 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Keyframe extraction for the annotation pipeline.
|
||||
|
||||
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||
visual content. The pipeline shares one provider across modules and one
|
||||
episode at a time, with a small per-episode cache so multiple modules
|
||||
querying the same timestamp pay decode cost once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.video_utils import decode_video_frames
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return one decoded frame per timestamp from ``camera_key`` (or default).
|
||||
|
||||
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
|
||||
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
|
||||
:func:`to_image_blocks` converts them to PIL only at the VLM-message
|
||||
boundary.
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
(the ``plan`` and ``interjections`` modules) keep working unchanged.
|
||||
"""
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` decoded frames covering the whole episode.
|
||||
|
||||
Sampling is uniform across the episode duration. Frames are
|
||||
``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
|
||||
them into one ``{"type":"video", "video":<list>}`` block for a
|
||||
Qwen-VL-compatible model that pools temporally itself. Empty list if
|
||||
no camera available.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _NullProvider:
|
||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def null_provider() -> FrameProvider:
|
||||
return _NullProvider()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||
|
||||
By default the *first* camera key is used for the ``plan`` module
|
||||
(subtask decomposition) and the ``interjections`` module (interjection
|
||||
scenarios) — those prompts care about *what is happening*, not which
|
||||
angle. The ``vqa`` module instead iterates over every camera in
|
||||
:attr:`camera_keys` so each frame's
|
||||
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||
grounded against.
|
||||
|
||||
``camera_key`` overrides the default-camera choice but does not restrict
|
||||
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||
``video_for_episode`` to read a non-default stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
|
||||
"""
|
||||
|
||||
root: Path
|
||||
camera_key: str | None = None
|
||||
tolerance_s: float = 1e-2
|
||||
cache_size: int = 256
|
||||
# Keyframe decode backend. ``None`` uses the ffmpeg CLI — the
|
||||
# concurrency- and crash-safe default for the pipeline's threaded
|
||||
# decode. Set to ``"torchcodec"`` or ``"pyav"`` to pin an in-process
|
||||
# decoder when the build is known thread-safe.
|
||||
video_backend: str | None = None
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
|
||||
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
|
||||
# one-shot warn flag against concurrent updates from worker threads.
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
# ``camera_keys`` covers both image- and video-stored cameras and is
|
||||
# always defined on the metadata (``[]`` in the worst case), so it is
|
||||
# the single source we need here.
|
||||
keys = list(self._meta.camera_keys)
|
||||
# Last-resort fallback: if metadata didn't surface anything but the
|
||||
# caller explicitly named a camera (``--vlm.camera_key=...``), trust
|
||||
# them — the key is by definition known to exist on the dataset.
|
||||
if not keys and self.camera_key:
|
||||
keys = [self.camera_key]
|
||||
self._camera_keys = keys
|
||||
if self.camera_key is None:
|
||||
self.camera_key = keys[0] if keys else None
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` keys available on this dataset."""
|
||||
return list(self._camera_keys)
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if not timestamps or target is None:
|
||||
return []
|
||||
|
||||
out: list[Any] = []
|
||||
misses: list[float] = []
|
||||
miss_indices: list[int] = []
|
||||
with self._lock:
|
||||
for i, ts in enumerate(timestamps):
|
||||
key = (record.episode_index, target, round(float(ts), 6))
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
out.append(cached)
|
||||
else:
|
||||
out.append(None)
|
||||
misses.append(float(ts))
|
||||
miss_indices.append(i)
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses, target)
|
||||
# ``_decode`` returns exactly one frame per requested timestamp,
|
||||
# or an empty list if decoding failed wholesale. A partial list
|
||||
# would mean a frame/timestamp misalignment, so only pair them up
|
||||
# when the counts match (``strict=True`` then guards regressions).
|
||||
if len(decoded) == len(miss_indices):
|
||||
with self._lock:
|
||||
for i, frame in zip(miss_indices, decoded, strict=True):
|
||||
out[i] = frame
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = frame
|
||||
# filter out any None left over from decode failures
|
||||
return [frame for frame in out if frame is not None]
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
|
||||
|
||||
The whole episode duration is covered; the model picks subtask
|
||||
boundaries from the temporal pooling it does internally. Frames are
|
||||
``torch.Tensor`` (see :meth:`frames_at`).
|
||||
"""
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||
return []
|
||||
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||
if n_frames == len(record.frame_timestamps):
|
||||
timestamps = list(record.frame_timestamps)
|
||||
else:
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
if t_last <= t0:
|
||||
timestamps = [float(t0)] * n_frames
|
||||
else:
|
||||
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||
return self.frames_at(record, timestamps, camera_key=target)
|
||||
|
||||
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
|
||||
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
|
||||
|
||||
Returns ``None`` if the dataset has no video tracks. Skips
|
||||
re-extract when the cached clip already exists. Re-encodes to
|
||||
H.264 (libx264) so the resulting mp4 is decodable by every
|
||||
downstream video processor — stream-copy would inherit the
|
||||
source codec (often AV1 in modern LeRobot datasets), which
|
||||
vllm's libav build cannot decode.
|
||||
"""
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
if self.camera_key is None:
|
||||
return None
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
|
||||
if out_path.exists() and out_path.stat().st_size > 0:
|
||||
return out_path
|
||||
ep = self._meta.episodes[record.episode_index]
|
||||
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
|
||||
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
|
||||
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-loglevel",
|
||||
"error",
|
||||
"-ss",
|
||||
f"{from_timestamp:.3f}",
|
||||
"-to",
|
||||
f"{to_timestamp:.3f}",
|
||||
"-i",
|
||||
str(src),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"ultrafast",
|
||||
"-crf",
|
||||
"23",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
"-an",
|
||||
str(out_path),
|
||||
]
|
||||
try:
|
||||
subprocess.run(cmd, check=True, timeout=300)
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return None
|
||||
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
|
||||
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
|
||||
|
||||
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
|
||||
(torchcodec by default, PyAV fallback) rather than a bespoke decoder.
|
||||
Returns one frame per requested timestamp, or ``[]`` if decoding
|
||||
failed wholesale — callers treat ``[]`` as "no frames available".
|
||||
"""
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
|
||||
# Default to the ffmpeg CLI. The pipeline decodes under a 16-wide
|
||||
# ThreadPoolExecutor and the in-process decoders are unsafe there:
|
||||
# torchcodec is not thread-safe and SIGSEGVs under concurrent decode
|
||||
# (a crash no try/except can catch), PyAV can likewise segfault on
|
||||
# AV1, and lerobot's ``pyav`` backend routes through the removed
|
||||
# ``torchvision.io.VideoReader``. ``_decode_frames_ffmpeg`` shells
|
||||
# out per frame: each decode is an isolated child process, so it is
|
||||
# both crash-safe and concurrency-safe. ``video_backend`` can pin
|
||||
# ``torchcodec`` / ``pyav`` explicitly for callers that know their
|
||||
# build is safe.
|
||||
chain = [self.video_backend] if self.video_backend else ["ffmpeg"]
|
||||
|
||||
exc: Exception | None = None
|
||||
for backend in chain:
|
||||
try:
|
||||
if backend == "ffmpeg":
|
||||
return _decode_frames_ffmpeg(video_path, shifted)
|
||||
if backend in ("pyav", "av"):
|
||||
return _decode_frames_av(video_path, shifted)
|
||||
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
|
||||
decoded = decode_video_frames(
|
||||
video_path, shifted, self.tolerance_s, backend=backend, return_uint8=True
|
||||
)
|
||||
return list(decoded)
|
||||
except Exception as e: # noqa: PERF203
|
||||
exc = e
|
||||
|
||||
# Every backend raised. Log loudly the first time so a silent
|
||||
# vqa-module no-op (every prompt skipped because frames_at returned
|
||||
# []) is debuggable from the job log instead of post-hoc parquet
|
||||
# inspection. Subsequent failures stay quiet.
|
||||
with self._lock:
|
||||
already_warned = getattr(self, "_warned_decode_fail", False)
|
||||
if not already_warned:
|
||||
self._warned_decode_fail = True
|
||||
if not already_warned:
|
||||
logger.warning(
|
||||
"VideoFrameProvider._decode failed for episode=%s camera=%s "
|
||||
"video_path=%s backends=%s: %s",
|
||||
episode_index,
|
||||
camera_key,
|
||||
video_path,
|
||||
chain,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def make_frame_provider(
|
||||
root: Path, camera_key: str | None = None, video_backend: str | None = None
|
||||
) -> FrameProvider:
|
||||
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||
try:
|
||||
provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend)
|
||||
except Exception:
|
||||
return null_provider()
|
||||
if provider.camera_key is None:
|
||||
return null_provider()
|
||||
return provider
|
||||
|
||||
|
||||
def _decode_frames_ffmpeg(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||
"""Decode the frames nearest to ``timestamps`` via the ffmpeg CLI.
|
||||
|
||||
Runs one ``ffmpeg`` process per timestamp, seeking with ``-ss`` and
|
||||
piping a single PNG to stdout. Unlike the in-process decoders this
|
||||
survives a hostile container: a full ffmpeg build decodes AV1 (the codec
|
||||
modern LeRobot datasets use) where torchcodec raises and PyAV can
|
||||
SIGSEGV, and a crash stays isolated to the child process — a non-zero
|
||||
exit is a catchable error, not a segfault of the whole job. Returns one
|
||||
``(C, H, W)`` uint8 tensor per timestamp.
|
||||
"""
|
||||
import io # noqa: PLC0415
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
import numpy as np # noqa: PLC0415
|
||||
|
||||
frames: list[Any] = []
|
||||
for ts in timestamps:
|
||||
proc = subprocess.run(
|
||||
[
|
||||
"ffmpeg", "-nostdin", "-loglevel", "error",
|
||||
"-ss", f"{max(ts, 0.0):.3f}",
|
||||
"-i", str(video_path),
|
||||
"-frames:v", "1",
|
||||
"-f", "image2pipe", "-vcodec", "png", "pipe:1",
|
||||
],
|
||||
capture_output=True,
|
||||
check=True,
|
||||
timeout=120,
|
||||
)
|
||||
if not proc.stdout:
|
||||
raise RuntimeError(f"ffmpeg returned no frame for t={ts:.3f}s of {video_path}")
|
||||
img = PIL.Image.open(io.BytesIO(proc.stdout)).convert("RGB")
|
||||
frames.append(torch.from_numpy(np.asarray(img).copy()).permute(2, 0, 1).contiguous())
|
||||
return frames
|
||||
|
||||
|
||||
def _decode_frames_av(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||
"""Decode the frames nearest to ``timestamps`` using PyAV directly.
|
||||
|
||||
lerobot's ``decode_video_frames(backend="pyav")`` routes through
|
||||
``torchvision.io.VideoReader``, removed in torchvision 0.23+. This helper
|
||||
talks to the ``av`` package directly. Note PyAV can SIGSEGV on AV1
|
||||
streams in some builds — prefer ``_decode_frames_ffmpeg`` as the default
|
||||
fallback; this stays available behind ``video_backend="pyav"``. Returns
|
||||
one ``(C, H, W)`` uint8 tensor per timestamp.
|
||||
"""
|
||||
import av # noqa: PLC0415
|
||||
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
loaded_frames: list[torch.Tensor] = []
|
||||
loaded_ts: list[float] = []
|
||||
with av.open(str(video_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
# Seek to the keyframe at or before the first requested timestamp.
|
||||
offset = max(int(first_ts / stream.time_base), 0) if stream.time_base else 0
|
||||
container.seek(offset, stream=stream, backward=True, any_frame=False)
|
||||
for idx, frame in enumerate(container.decode(stream)):
|
||||
ts = frame.time
|
||||
if ts is None:
|
||||
ts = float(frame.pts * stream.time_base) if frame.pts is not None else float(idx)
|
||||
loaded_ts.append(ts)
|
||||
loaded_frames.append(
|
||||
torch.from_numpy(frame.to_ndarray(format="rgb24")).permute(2, 0, 1).contiguous()
|
||||
)
|
||||
if ts >= last_ts:
|
||||
break
|
||||
if not loaded_frames:
|
||||
raise RuntimeError(f"PyAV decoded no frames from {video_path}")
|
||||
ts_tensor = torch.tensor(loaded_ts)
|
||||
return [loaded_frames[int(torch.argmin((ts_tensor - q).abs()))] for q in timestamps]
|
||||
|
||||
|
||||
def _frame_to_pil(frame: Any) -> Any:
|
||||
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
|
||||
|
||||
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
|
||||
straight from :func:`decode_video_frames`); PIL is only created here, at
|
||||
the VLM-message boundary, because the chat backends expect PIL images /
|
||||
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
|
||||
"""
|
||||
if not isinstance(frame, torch.Tensor):
|
||||
return frame
|
||||
array = frame.detach().cpu()
|
||||
if array.ndim == 3 and array.shape[0] in (1, 3):
|
||||
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
|
||||
if array.shape[-1] == 1:
|
||||
array = array.squeeze(-1)
|
||||
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
|
||||
|
||||
|
||||
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert decoded frames to Qwen-VL-compatible image content blocks."""
|
||||
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
|
||||
|
||||
|
||||
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Wrap a list of decoded frames as one Qwen-VL video block.
|
||||
|
||||
Returns ``[]`` when the list is empty, so the caller can splat the result
|
||||
into a content array without a separate emptiness check.
|
||||
"""
|
||||
if not frames:
|
||||
return []
|
||||
return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
|
||||
|
||||
|
||||
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
|
||||
"""Wrap a video file URL as one ``video_url`` block.
|
||||
|
||||
Used by the ``openai`` backend (transformers serve / vllm serve /
|
||||
ktransformers serve), where the server handles frame sampling.
|
||||
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
|
||||
"""
|
||||
if not url:
|
||||
return []
|
||||
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .general_vqa import GeneralVqaModule
|
||||
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||
|
||||
__all__ = [
|
||||
"GeneralVqaModule",
|
||||
"InterjectionsAndSpeechModule",
|
||||
"PlanSubtasksMemoryModule",
|
||||
]
|
||||
@@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``vqa`` module: general VQA at a timed cadence.
|
||||
|
||||
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
|
||||
consecutive frames, and every anchored frame gets its own VQA pair. Each
|
||||
pair is grounded on that single anchor frame — there is no per-pair frame
|
||||
window. For datasets with multiple cameras, every anchored frame produces
|
||||
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||
generated against that camera's frame and stamped with the matching
|
||||
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||
|
||||
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||
|
||||
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
|
||||
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||
whose schema depends on the question type. Malformed JSON triggers one
|
||||
retry inside :meth:`VlmClient.generate_json`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import VqaConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
from ..validator import classify_vqa_answer
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
|
||||
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
|
||||
"""Return the relative frame indices to anchor VQA emissions to.
|
||||
|
||||
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
|
||||
consecutive frames starting at the tick. Ticks fall on the nearest
|
||||
available source frame timestamp.
|
||||
"""
|
||||
if hz <= 0 or k <= 0 or not frame_timestamps:
|
||||
return []
|
||||
t0 = frame_timestamps[0]
|
||||
t_last = frame_timestamps[-1]
|
||||
period = 1.0 / hz
|
||||
indices: list[int] = []
|
||||
t = t0
|
||||
while t <= t_last + 1e-9:
|
||||
# find the index of the nearest frame to t
|
||||
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
|
||||
for offset in range(k):
|
||||
j = nearest_i + offset
|
||||
if j >= len(frame_timestamps):
|
||||
break
|
||||
if not indices or indices[-1] != j:
|
||||
indices.append(j)
|
||||
t += period
|
||||
# dedupe while preserving order
|
||||
seen: set[int] = set()
|
||||
deduped: list[int] = []
|
||||
for i in indices:
|
||||
if i in seen:
|
||||
continue
|
||||
seen.add(i)
|
||||
deduped.append(i)
|
||||
return deduped
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralVqaModule:
|
||||
"""Emit grounded VQA pairs at a timed cadence."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: VqaConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
if not record.frame_timestamps:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
|
||||
anchor_idx = _emission_anchor_indices(
|
||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||
)
|
||||
cameras = self._target_cameras()
|
||||
if not cameras:
|
||||
# No camera available — emit nothing rather than producing
|
||||
# untagged rows that would fail validation. Surface a loud one-
|
||||
# time warning so this is never silently a no-op.
|
||||
if not getattr(self, "_warned_no_camera", False):
|
||||
logging.getLogger(__name__).warning(
|
||||
"vqa module found no cameras on the frame provider — "
|
||||
"every episode will emit zero VQA rows. Check that the "
|
||||
"dataset declares observation.images.* features in "
|
||||
"meta/info.json; passing --vlm.camera_key=<key> at the "
|
||||
"CLI now also seeds the cameras list as a fallback."
|
||||
)
|
||||
self._warned_no_camera = True
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
# Build all messages first (one per (frame, camera)), then issue them
|
||||
# as a single batched generate_json call so the client can fan them
|
||||
# out concurrently.
|
||||
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
|
||||
for idx in anchor_idx:
|
||||
ts = float(record.frame_timestamps[idx])
|
||||
qtype = rng.choice(self.config.question_types)
|
||||
for camera in cameras:
|
||||
messages = self._build_messages(record, qtype, ts, camera)
|
||||
# Skip cameras that decoded to zero frames at this ts: no point
|
||||
# asking the VLM to ground a bbox without an image.
|
||||
if not _has_image_block(messages):
|
||||
continue
|
||||
per_call.append((ts, camera, qtype, messages))
|
||||
|
||||
if not per_call:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
|
||||
qa = self._postprocess(result)
|
||||
if qa is None:
|
||||
continue
|
||||
question, answer = qa
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(answer, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("vqa", rows)
|
||||
|
||||
def _target_cameras(self) -> list[str]:
|
||||
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
|
||||
|
||||
Defaults to every camera the provider exposes. Datasets with no
|
||||
cameras (or test/null providers) yield an empty list, which makes
|
||||
``run_episode`` a no-op.
|
||||
"""
|
||||
return list(getattr(self.frame_provider, "camera_keys", []) or [])
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
prompt = load_prompt("module_3_vqa").format(
|
||||
episode_task=record.episode_task,
|
||||
question_type=question_type,
|
||||
)
|
||||
images = self.frame_provider.frames_at(
|
||||
record, [frame_timestamp], camera_key=camera_key
|
||||
)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
question = result.get("question")
|
||||
answer = result.get("answer")
|
||||
if not isinstance(question, str) or not question.strip():
|
||||
return None
|
||||
if not isinstance(answer, dict):
|
||||
return None
|
||||
# The validator will enforce shape; here we just sanity-check that the
|
||||
# answer matches *some* known shape so we can drop garbage early.
|
||||
if classify_vqa_answer(answer) is None:
|
||||
return None
|
||||
return question.strip(), answer
|
||||
|
||||
|
||||
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Return True if any user content block is a populated image block."""
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "image":
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,210 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
|
||||
|
||||
Two sub-passes:
|
||||
|
||||
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
|
||||
canonical task). No interjection row — the canonical task is already the
|
||||
user utterance from ``meta/tasks.parquet``.
|
||||
|
||||
2. For mid-episode interruptions, emit a co-timestamped pair:
|
||||
{role:user, style:interjection, content:<text>}
|
||||
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
||||
Both rows go in ``language_events`` at the same timestamp.
|
||||
|
||||
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
|
||||
interjection timestamps to refresh the ``plan`` row at the same instant.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import InterjectionsConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..writer import speech_atom
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsAndSpeechModule:
|
||||
"""Generate task-start speech and mid-episode interjection/speech pairs."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: InterjectionsConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
if record.frame_timestamps:
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
initial = self._initial_speech(record)
|
||||
if initial:
|
||||
rows.append(speech_atom(t0, initial))
|
||||
# Pull the ``plan`` module's subtask spans for this episode so the
|
||||
# interjection prompt can ground itself in the actual current
|
||||
# subtask at each chosen timestamp. The ``plan`` module ran first.
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
|
||||
rows.extend(self._mid_episode_interjections(record, subtask_spans))
|
||||
staging.write("interjections", rows)
|
||||
|
||||
@staticmethod
|
||||
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
|
||||
current: str | None = None
|
||||
for span in spans:
|
||||
if float(span["start"]) <= t:
|
||||
current = span.get("text")
|
||||
else:
|
||||
break
|
||||
return current
|
||||
|
||||
def _initial_speech(self, record: EpisodeRecord) -> str | None:
|
||||
prompt = load_prompt("module_2_initial_speech").format(
|
||||
episode_task=record.episode_task,
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict) and isinstance(result.get("text"), str):
|
||||
text = result["text"].strip()
|
||||
if text:
|
||||
return text
|
||||
return None
|
||||
|
||||
def _mid_episode_interjections(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate interjections aligned with the actual demo trajectory.
|
||||
|
||||
Teleop data is frozen — the robot already executed every step in
|
||||
the video. A *counterfactual* interjection like "actually skip
|
||||
the wipe" contradicts what then happens in the video, which is
|
||||
what qwen36moe-10/11 surfaced as low-quality interjections.
|
||||
|
||||
Instead, anchor every interjection at a subtask boundary and
|
||||
write it as a natural user request for the *upcoming* subtask.
|
||||
The robot's visible next behavior IS the interjection's effect,
|
||||
so the training signal stays consistent: interjection text →
|
||||
plan refresh → action stream all line up.
|
||||
"""
|
||||
if self.config.max_interjections_per_episode <= 0:
|
||||
return []
|
||||
if len(subtask_spans) < 2:
|
||||
# Need at least one transition (subtask 0 → subtask 1).
|
||||
return []
|
||||
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
|
||||
|
||||
# Boundaries: the start time of every subtask except the first
|
||||
# (which is just t0 and is covered by the initial-task speech atom).
|
||||
boundaries: list[tuple[float, str, str]] = []
|
||||
for i in range(1, len(subtask_spans)):
|
||||
ts = float(subtask_spans[i]["start"])
|
||||
if ts < self.config.interjection_min_t:
|
||||
continue
|
||||
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
|
||||
next_text = (subtask_spans[i].get("text") or "").strip()
|
||||
if not next_text:
|
||||
continue
|
||||
boundaries.append((ts, prev_text, next_text))
|
||||
if not boundaries:
|
||||
return []
|
||||
|
||||
n = min(self.config.max_interjections_per_episode, len(boundaries))
|
||||
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
|
||||
|
||||
out: list[dict[str, Any]] = []
|
||||
for t, prev_subtask, next_subtask in chosen:
|
||||
t_snap = snap_to_frame(t, record.frame_timestamps)
|
||||
# Window straddles the boundary so the VLM sees the end of the
|
||||
# previous subtask and the start of the next one — same
|
||||
# conditioning the policy will see at training time.
|
||||
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
|
||||
prompt = load_prompt("module_2_interjection").format(
|
||||
episode_task=record.episode_task,
|
||||
prev_subtask=prev_subtask or "(starting from initial state)",
|
||||
next_subtask=next_subtask,
|
||||
timestamp=t_snap,
|
||||
window_seconds=self.config.interjection_window_seconds,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, window_ts)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
messages = [{"role": "user", "content": content}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
interjection_text = result.get("interjection")
|
||||
speech_text = result.get("speech")
|
||||
if not isinstance(interjection_text, str) or not interjection_text.strip():
|
||||
continue
|
||||
if not isinstance(speech_text, str) or not speech_text.strip():
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": interjection_text.strip(),
|
||||
"style": "interjection",
|
||||
"timestamp": t_snap,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
out.append(speech_atom(t_snap, speech_text.strip()))
|
||||
return out
|
||||
|
||||
def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
|
||||
"""Return a small set of frame timestamps centered on ``t_anchor``.
|
||||
|
||||
The window straddles the subtask boundary the interjection sits
|
||||
on: roughly half the frames cover the end of the previous
|
||||
subtask, half cover the start of the next one. The VLM therefore
|
||||
sees BOTH what just finished AND what's about to start, which is
|
||||
the conditioning we need to write a natural "now please do X"
|
||||
request that matches the visible upcoming behavior.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return [t_anchor]
|
||||
n = max(1, int(self.config.interjection_window_frames))
|
||||
if n == 1:
|
||||
return [t_anchor]
|
||||
window = float(self.config.interjection_window_seconds)
|
||||
step = window / max(1, n - 1)
|
||||
# Center the window on the anchor so half lands before, half after.
|
||||
start_offset = -window / 2.0
|
||||
targets = [t_anchor + start_offset + step * i for i in range(n)]
|
||||
last_ts = float(frame_timestamps[-1])
|
||||
snapped: list[float] = []
|
||||
seen: set[float] = set()
|
||||
for tgt in targets:
|
||||
clamped = min(last_ts, max(0.0, tgt))
|
||||
t = snap_to_frame(clamped, frame_timestamps)
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
snapped.append(t)
|
||||
return snapped or [t_anchor]
|
||||
@@ -0,0 +1,617 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ..config import PlanConfig
|
||||
from ..frames import (
|
||||
FrameProvider,
|
||||
VideoFrameProvider,
|
||||
null_provider,
|
||||
to_video_block,
|
||||
to_video_url_block,
|
||||
)
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..vocabulary import Vocabulary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanSubtasksMemoryModule:
|
||||
"""Generate subtask spans, plan, and memory rows.
|
||||
|
||||
All output is persistent (lives in ``language_persistent``):
|
||||
|
||||
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
|
||||
(snapped to an exact frame).
|
||||
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
|
||||
timestamp via :meth:`run_plan_updates` (called by the executor after
|
||||
the ``interjections`` module completes).
|
||||
- ``memory`` rows: emitted at each subtask boundary (= subtask start
|
||||
timestamp from the second subtask onward).
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: PlanConfig
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
vocabulary: Vocabulary | None = None
|
||||
"""When set, the module constrains subtask + memory generation to the
|
||||
canonical strings in ``vocabulary``. Phase 0 (vocabulary discovery)
|
||||
populates this once per dataset; ``None`` falls back to free-form
|
||||
generation (original behaviour)."""
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
# Resolve the task that drives every other ``plan``-module prompt.
|
||||
# May be the canonical ``record.episode_task`` (default), or a fresh
|
||||
# description derived from the video when the canonical task is
|
||||
# empty / placeholder / forced-off (see PlanConfig.derive_task_*).
|
||||
effective_task = self._resolve_effective_task(record)
|
||||
# ``task_aug`` rows at t=0 (role=user), one per rephrasing — the
|
||||
# message renderer rotates ``${task}`` deterministically through
|
||||
# them so the policy sees diverse phrasings during training.
|
||||
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
|
||||
if self.config.n_task_rephrasings > 0 and effective_task:
|
||||
rephrasings = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
|
||||
# Always include the effective task itself as the first variant
|
||||
# so the rotation is guaranteed to cover the source-of-truth
|
||||
# phrasing, not just synthetic alternatives.
|
||||
seen: set[str] = set()
|
||||
ordered = [effective_task, *rephrasings]
|
||||
for phrasing in ordered:
|
||||
key = phrasing.strip()
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": key,
|
||||
"style": "task_aug",
|
||||
"timestamp": t0,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
|
||||
subtask_spans = self._generate_subtasks(record, task=effective_task)
|
||||
# subtask rows
|
||||
for span in subtask_spans:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": span["text"],
|
||||
"style": "subtask",
|
||||
"timestamp": snap_to_frame(span["start"], record.frame_timestamps),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# Plan rows at every subtask boundary — including t=0 (start of
|
||||
# the first subtask). Because the plan is just a numbered list
|
||||
# of *still-todo* subtasks, re-emitting at each boundary makes
|
||||
# the active plan shrink as work progresses: at frame t the
|
||||
# rendered ``${plan}`` is the most recent emission, which
|
||||
# contains exactly the subtasks that started at or after the
|
||||
# current span. Saves the runtime from having to derive
|
||||
# "what's still left" at inference time.
|
||||
for span in subtask_spans:
|
||||
boundary_t = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
plan_text = self._generate_plan(
|
||||
record, subtask_spans, refresh_t=boundary_t, task=effective_task
|
||||
)
|
||||
if plan_text is not None:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": float(boundary_t),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# memory rows at every subtask boundary except the very first start
|
||||
prior_memory = ""
|
||||
for i, span in enumerate(subtask_spans[1:], start=1):
|
||||
completed = subtask_spans[i - 1]["text"]
|
||||
remaining = [s["text"] for s in subtask_spans[i:]]
|
||||
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
|
||||
if mem_text:
|
||||
ts = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": mem_text,
|
||||
"style": "memory",
|
||||
"timestamp": ts,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
prior_memory = mem_text
|
||||
staging.write("plan", rows)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Task derivation + rephrasings
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
|
||||
{
|
||||
"debug",
|
||||
"test",
|
||||
"tbd",
|
||||
"todo",
|
||||
"n/a",
|
||||
"na",
|
||||
"untitled",
|
||||
"unnamed",
|
||||
"default",
|
||||
"placeholder",
|
||||
}
|
||||
)
|
||||
|
||||
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
|
||||
"""Decide which task string drives the ``plan`` module for this episode.
|
||||
|
||||
Returns the user-supplied ``record.episode_task`` unless
|
||||
``derive_task_from_video`` says otherwise (see config docstring).
|
||||
Falls back gracefully to the canonical task if video derivation
|
||||
fails.
|
||||
"""
|
||||
canonical = (record.episode_task or "").strip()
|
||||
mode = (self.config.derive_task_from_video or "off").strip().lower()
|
||||
if mode == "always":
|
||||
derived = self._derive_task_from_video(record)
|
||||
return derived or canonical
|
||||
if mode == "if_short" and self._task_seems_bad(canonical):
|
||||
derived = self._derive_task_from_video(record)
|
||||
if derived:
|
||||
return derived
|
||||
return canonical
|
||||
|
||||
def _task_seems_bad(self, task: str) -> bool:
|
||||
if not task:
|
||||
return True
|
||||
if len(task.split()) < int(self.config.derive_task_min_words):
|
||||
return True
|
||||
return task.lower() in self._PLACEHOLDER_TASKS
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# VLM call helpers (factored out: every ``plan``-module prompt below follows
|
||||
# the same "build messages → single VLM call → pull a named field"
|
||||
# shape, only differing in field name + post-processing).
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
|
||||
"""Run a single VLM call and return ``result[field]`` or ``None``.
|
||||
|
||||
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
|
||||
dance every prompt-call site needs.
|
||||
"""
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _text_message(text: str) -> list[dict[str, Any]]:
|
||||
"""One-shot text-only user message wrapped for ``generate_json``."""
|
||||
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
|
||||
|
||||
def _video_message(self, record: EpisodeRecord, prompt: str) -> list[dict[str, Any]]:
|
||||
"""User message combining the episode video block with ``prompt``."""
|
||||
content = [*self._episode_video_block(record), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
|
||||
"""Ask the VLM "what is this video about" with no task hint at all."""
|
||||
text = self._vlm_field(self._video_message(record, load_prompt("module_1_video_task")), "task")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else None
|
||||
|
||||
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
|
||||
"""Generate ``n`` text-only paraphrases of ``base_task``."""
|
||||
if n <= 0 or not base_task:
|
||||
return []
|
||||
prompt = load_prompt("module_1_task_rephrasings").format(base_task=base_task, n=n)
|
||||
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
|
||||
return [s for s in out if s][:n]
|
||||
|
||||
def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]:
|
||||
"""Same video block ``_generate_subtasks`` builds — extracted helper."""
|
||||
if not record.frame_timestamps:
|
||||
return []
|
||||
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
|
||||
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
|
||||
clip = self.frame_provider.episode_clip_path(record, cache_dir)
|
||||
return (
|
||||
to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
|
||||
if clip is not None
|
||||
else []
|
||||
)
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
target_count = max(1, int(round(episode_duration * self.config.frames_per_second)))
|
||||
target_count = min(target_count, self.config.max_video_frames)
|
||||
video_frames = self.frame_provider.video_for_episode(record, target_count)
|
||||
return to_video_block(video_frames)
|
||||
|
||||
def run_plan_updates(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging: EpisodeStaging,
|
||||
interjection_times: Sequence[float],
|
||||
interjection_texts: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Append additional ``plan`` rows at every interjection timestamp.
|
||||
|
||||
Plans refresh ONLY on user interjections — subtask generation
|
||||
runs ~1 Hz at inference, but plan re-emission is event-driven.
|
||||
Now also forwards the interjection's own text into the prompt so
|
||||
the refreshed plan can actually reflect the user's correction
|
||||
(the previous version told the model "an interjection happened"
|
||||
without telling it what the user said).
|
||||
"""
|
||||
existing = staging.read("plan")
|
||||
# Pass the episode's last frame timestamp so the final subtask
|
||||
# span is closed (otherwise its ``end`` equals its ``start``,
|
||||
# zero duration, and the "current subtask at refresh_t" lookup
|
||||
# in ``_generate_plan`` misses any refresh that lands inside it).
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
|
||||
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
|
||||
new_rows = list(existing)
|
||||
|
||||
texts: list[str | None] = (
|
||||
[None] * len(interjection_times)
|
||||
if interjection_texts is None
|
||||
else [str(t) if t else None for t in interjection_texts]
|
||||
)
|
||||
for raw_t, inter_text in zip(interjection_times, texts, strict=True):
|
||||
t = snap_to_frame(raw_t, record.frame_timestamps)
|
||||
if t in already_planned:
|
||||
continue
|
||||
already_planned.add(t)
|
||||
plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
|
||||
if plan_text is not None:
|
||||
new_rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": t,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("plan", new_rows)
|
||||
|
||||
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
|
||||
if record.row_count == 0 or not record.frame_timestamps:
|
||||
return []
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
prompt = load_prompt("module_1_subtasks").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{episode_duration:.3f}",
|
||||
vocabulary_block=self._subtask_vocabulary_block(),
|
||||
)
|
||||
messages = self._video_message(record, prompt)
|
||||
spans = self._vlm_field(messages, "subtasks")
|
||||
# When a vocabulary is in force, do a single targeted retry if
|
||||
# any returned subtask is off-vocab — strict exact-match only,
|
||||
# no fuzzy snapping. The retry includes the offending strings
|
||||
# and the full canonical list so the VLM can correct itself.
|
||||
if self.vocabulary is not None and self.vocabulary.subtasks and spans:
|
||||
invalid = self._invalid_subtasks(spans)
|
||||
if invalid:
|
||||
logger.info(
|
||||
"episode %d: VLM emitted %d off-vocab subtask(s) (%s); retrying once",
|
||||
record.episode_index,
|
||||
len(invalid),
|
||||
invalid,
|
||||
)
|
||||
retry_msg = self._build_subtask_retry_message(messages, invalid)
|
||||
retried = self._vlm_field(retry_msg, "subtasks")
|
||||
if retried:
|
||||
spans = retried
|
||||
|
||||
if not spans:
|
||||
return []
|
||||
# clamp to [t0, t_last] and sort
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
cleaned: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
try:
|
||||
start = float(span["start"])
|
||||
end = float(span["end"])
|
||||
text = str(span["text"]).strip()
|
||||
except (KeyError, ValueError, TypeError):
|
||||
continue
|
||||
start = max(t0, min(start, t_last))
|
||||
end = max(t0, min(end, t_last))
|
||||
if end < start:
|
||||
start, end = end, start
|
||||
if not text:
|
||||
continue
|
||||
text = self._canonicalize_subtask(text)
|
||||
if not text:
|
||||
continue
|
||||
cleaned.append({"text": text, "start": start, "end": end})
|
||||
cleaned.sort(key=lambda s: s["start"])
|
||||
cleaned = self._dedupe_starts_to_distinct_frames(cleaned, record)
|
||||
if self.vocabulary is not None and self.vocabulary.subtasks and not cleaned:
|
||||
logger.warning(
|
||||
"episode %d: every VLM subtask was off-vocab even after retry — "
|
||||
"episode left empty (extend meta/canonical_vocabulary.json to "
|
||||
"cover the missing phase)",
|
||||
record.episode_index,
|
||||
)
|
||||
return cleaned
|
||||
|
||||
@staticmethod
|
||||
def _dedupe_starts_to_distinct_frames(
|
||||
spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Bump same-frame subtask starts onto distinct frames.
|
||||
|
||||
Two consecutive VLM spans whose ``start`` rounds to the same
|
||||
source frame (after :func:`snap_to_frame`) would otherwise emit
|
||||
two ``style=subtask`` rows at the identical persistent
|
||||
timestamp. The training-time renderer's ``active_at(t,
|
||||
style=subtask)`` resolver can't disambiguate that and raises
|
||||
``Ambiguous resolver for style='subtask'``.
|
||||
|
||||
Walk the (sorted-by-start) spans, snap each to its frame, and
|
||||
if the snapped frame is already taken push the span onto the
|
||||
next unused frame so both subtasks survive on distinct
|
||||
timestamps. If the episode ends before a free frame is found,
|
||||
the trailing span is dropped with a warning — better than
|
||||
poisoning the render.
|
||||
"""
|
||||
if not spans:
|
||||
return spans
|
||||
frames = record.frame_timestamps
|
||||
if not frames:
|
||||
return spans
|
||||
used: set[float] = set()
|
||||
out: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
ts = snap_to_frame(span["start"], frames)
|
||||
if ts in used:
|
||||
next_ts = next((f for f in frames if f > ts and f not in used), None)
|
||||
if next_ts is None:
|
||||
logger.warning(
|
||||
"episode %d: subtask %r snapped to occupied frame "
|
||||
"%.3f and no free later frame exists — dropping",
|
||||
record.episode_index,
|
||||
span.get("text"),
|
||||
ts,
|
||||
)
|
||||
continue
|
||||
ts = next_ts
|
||||
used.add(ts)
|
||||
new_span = {**span, "start": ts}
|
||||
if float(new_span.get("end", ts)) < ts:
|
||||
new_span["end"] = ts
|
||||
out.append(new_span)
|
||||
return out
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Canonical-vocabulary helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _subtask_vocabulary_block(self) -> str:
|
||||
"""Bullet-list of canonical subtasks the VLM must pick from.
|
||||
|
||||
Returns an empty string when no vocabulary is configured —
|
||||
``module_1_subtasks.txt`` then falls back to its free-form
|
||||
rules (original behaviour).
|
||||
"""
|
||||
if self.vocabulary is None or not self.vocabulary.subtasks:
|
||||
return ""
|
||||
bullets = "\n".join(f"- {s}" for s in self.vocabulary.subtasks)
|
||||
return (
|
||||
"You MUST choose each subtask label verbatim from this canonical "
|
||||
"vocabulary — pick the closest match for each phase of the demo, "
|
||||
"and reuse the SAME string every time that phase recurs. The "
|
||||
"low-level policy is conditioned on these exact strings; any "
|
||||
"novel paraphrase you invent will make its conditioning OOD.\n"
|
||||
"Canonical subtask labels:\n"
|
||||
f"{bullets}\n\n"
|
||||
)
|
||||
|
||||
def _memory_vocabulary_block(self) -> str:
|
||||
"""Bullet-list of canonical memory milestones the VLM must pick from."""
|
||||
if self.vocabulary is None or not self.vocabulary.memory_milestones:
|
||||
return ""
|
||||
bullets = "\n".join(f"- {m}" for m in self.vocabulary.memory_milestones)
|
||||
return (
|
||||
"Compose the memory by picking ONLY from this canonical milestone "
|
||||
"list — append a milestone (or rewrite the running memory to "
|
||||
"compress past ones) using these exact phrases. Do not invent new "
|
||||
"wording: every paraphrase weakens the downstream conditioning.\n"
|
||||
"Canonical memory milestones:\n"
|
||||
f"{bullets}\n\n"
|
||||
)
|
||||
|
||||
_NORMALIZE_STRIP_TOKENS: frozenset[str] = frozenset({"the", "a", "an"})
|
||||
|
||||
def _canonicalize_subtask(self, text: str) -> str:
|
||||
"""Validate ``text`` against the canonical vocabulary; no fuzzy snap.
|
||||
|
||||
Without a vocabulary, the original text passes through. With a
|
||||
vocabulary, accept the span only if its normalised form (lower-
|
||||
cased, articles stripped, whitespace collapsed) matches a
|
||||
canonical entry exactly — the canonical wording is returned so
|
||||
the supervised string is byte-identical across episodes.
|
||||
|
||||
Off-vocab spans are dropped (empty string). Upstream
|
||||
``_generate_subtasks`` triggers a targeted retry before reaching
|
||||
the drop path; this function never snaps or warps a span into
|
||||
a different label.
|
||||
"""
|
||||
if self.vocabulary is None or not self.vocabulary.subtasks:
|
||||
return text.strip()
|
||||
normalised = self._normalize(text)
|
||||
if not normalised:
|
||||
return ""
|
||||
for candidate in self.vocabulary.subtasks:
|
||||
if self._normalize(candidate) == normalised:
|
||||
return candidate
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def _normalize(cls, text: str) -> str:
|
||||
"""Lowercase, strip articles, collapse whitespace, drop punctuation."""
|
||||
words = [
|
||||
w.strip(".,:;\"'!?()")
|
||||
for w in text.lower().replace(",", " ").split()
|
||||
]
|
||||
return " ".join(w for w in words if w and w not in cls._NORMALIZE_STRIP_TOKENS)
|
||||
|
||||
def _invalid_subtasks(self, spans: list[dict[str, Any]]) -> list[str]:
|
||||
"""Return the unique off-vocab subtask strings the VLM produced."""
|
||||
seen: list[str] = []
|
||||
for span in spans:
|
||||
text = str((span or {}).get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
if self._canonicalize_subtask(text):
|
||||
continue
|
||||
if text not in seen:
|
||||
seen.append(text)
|
||||
return seen
|
||||
|
||||
def _build_subtask_retry_message(
|
||||
self, original_messages: list[dict[str, Any]], invalid: list[str]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Compose a one-shot correction prompt naming the off-vocab strings."""
|
||||
assert self.vocabulary is not None
|
||||
canonical = "\n".join(f"- {s}" for s in self.vocabulary.subtasks)
|
||||
invalid_list = "\n".join(f"- {s!r}" for s in invalid)
|
||||
correction = (
|
||||
"Your previous response included subtask labels that are NOT in "
|
||||
"the canonical vocabulary:\n"
|
||||
f"{invalid_list}\n\n"
|
||||
"Re-emit the same segmentation (same number of spans, same start/end "
|
||||
"timestamps where they were valid) but replace every off-vocab "
|
||||
"label with the EXACT canonical string for that phase, copied "
|
||||
"verbatim from this list:\n"
|
||||
f"{canonical}\n\n"
|
||||
"Strict rules:\n"
|
||||
"- Output strings must be byte-for-byte identical to entries above.\n"
|
||||
"- No articles, no adverbs, no extra words.\n"
|
||||
"- If a phase truly has no canonical match, omit that span entirely.\n"
|
||||
"Return the same JSON shape as before."
|
||||
)
|
||||
# Append the correction as an additional user turn; the model
|
||||
# sees the original prompt + its prior output is implied by the
|
||||
# conversation context (the VLM client is stateless, so we
|
||||
# re-send the original content plus this correction).
|
||||
retry_messages = [
|
||||
{
|
||||
"role": m.get("role", "user"),
|
||||
"content": (
|
||||
m.get("content")
|
||||
if isinstance(m.get("content"), str)
|
||||
else list(m.get("content") or [])
|
||||
),
|
||||
}
|
||||
for m in original_messages
|
||||
]
|
||||
retry_messages.append({"role": "user", "content": correction})
|
||||
return retry_messages
|
||||
|
||||
def _generate_plan(
|
||||
self,
|
||||
record: EpisodeRecord, # noqa: ARG002 (kept for signature stability)
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
*,
|
||||
refresh_t: float | None = None,
|
||||
interjection: str | None = None, # noqa: ARG002
|
||||
task: str | None = None, # noqa: ARG002
|
||||
) -> str | None:
|
||||
"""Deterministic plan = numbered list of *still-todo* subtasks.
|
||||
|
||||
Previously this called the VLM with a prompt that asked it to
|
||||
compress the subtasks into a "compact hierarchical plan". That
|
||||
produced longer-than-necessary plans, cost an extra VLM round-trip
|
||||
per episode (plus one per interjection on refresh), and could
|
||||
diverge from the actual subtask sequence the model is going to
|
||||
execute. Replacing it with a plain summarisation keeps the plan
|
||||
tightly aligned with the upcoming subtasks and removes the VLM
|
||||
call entirely.
|
||||
|
||||
Layout — short imperative fragments prefixed by "N. ":
|
||||
|
||||
1. <subtask 1>
|
||||
2. <subtask 2>
|
||||
...
|
||||
|
||||
On a refresh at ``refresh_t`` (called from ``run_plan_updates``
|
||||
on interjection events, and from ``run_episode`` at every subtask
|
||||
boundary), only subtasks whose start is at or after ``refresh_t``
|
||||
are included — the plan shrinks as work progresses, so it always
|
||||
describes what's left.
|
||||
"""
|
||||
if not subtask_spans:
|
||||
return None
|
||||
remaining = [
|
||||
s
|
||||
for s in subtask_spans
|
||||
if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
|
||||
]
|
||||
if not remaining:
|
||||
# Past the last subtask boundary on a late refresh — nothing
|
||||
# left to plan; emit None so the caller skips the row.
|
||||
return None
|
||||
return "\n".join(
|
||||
f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1)
|
||||
)
|
||||
|
||||
def _generate_memory(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prior_memory: str,
|
||||
completed: str,
|
||||
remaining: Sequence[str],
|
||||
*,
|
||||
task: str | None = None,
|
||||
) -> str:
|
||||
prompt = load_prompt("module_1_memory").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
prior_memory=prior_memory or "(none)",
|
||||
completed_subtask=completed,
|
||||
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
|
||||
vocabulary_block=self._memory_vocabulary_block(),
|
||||
)
|
||||
memory = self._vlm_field(self._text_message(prompt), "memory")
|
||||
return memory.strip() if isinstance(memory, str) else ""
|
||||
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Prompt templates loaded as plain text.
|
||||
|
||||
One file per use site. Templates use ``str.format(**vars)`` substitution; we
|
||||
intentionally avoid jinja2 here so the templates remain inspectable in
|
||||
plain editors and roundtrip cleanly through ``ruff format``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def load(name: str) -> str:
|
||||
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
|
||||
path = _DIR / f"{name}.txt"
|
||||
return path.read_text(encoding="utf-8")
|
||||
@@ -0,0 +1,53 @@
|
||||
You are inspecting {n_episodes} sample episode video(s) from a teleoperated
|
||||
robot dataset. Every episode in the dataset performs the SAME task; the
|
||||
user originally asked: "{episode_task}".
|
||||
|
||||
Watch all the clips and produce a SHORT canonical vocabulary that every
|
||||
episode in this dataset will reuse. The downstream low-level policy is
|
||||
conditioned on these strings — duplicate phrasings (e.g. "grasp blue
|
||||
cube" vs "pick up the blue cube") would destroy the conditioning, so
|
||||
pick one wording per concept and reuse it everywhere.
|
||||
|
||||
Decide how many entries each list needs YOURSELF based on what you see —
|
||||
the smallest set that still covers every recurring phase in the demos.
|
||||
A simple two-object pick-and-place might need ~6 subtask labels and 2
|
||||
memory milestones; a long multi-step recipe needs more. Err on the side
|
||||
of FEWER — extra entries that don't recur across episodes weaken the
|
||||
conditioning.
|
||||
|
||||
You output two lists:
|
||||
|
||||
1. `subtasks`: imperative, telegraphic commands the robot can execute.
|
||||
- Verb-first. Drop articles, adverbs, qualifiers.
|
||||
- Consistent object nouns (if the task says "cube", every subtask says
|
||||
"cube" — never "block" / "object").
|
||||
- Atomic — one skill per subtask (gripper-open events, contact, regrasps,
|
||||
transitions all become cut points).
|
||||
- Each label must recur across the demos. If you see a motion only
|
||||
once across all sample clips, it probably isn't a canonical phase.
|
||||
- Good: "move to blue cube", "grasp blue cube", "lift blue cube",
|
||||
"place blue cube in box", "release blue cube", "retract arm".
|
||||
- Bad: "the robot arm moves towards the blue cube" (third person,
|
||||
too long), "carefully pick up the cube" (adverb, article),
|
||||
"carrying the yellow cube over the green basket" (gerund — should
|
||||
be imperative "transport yellow cube to green basket").
|
||||
|
||||
2. `memory_milestones`: first-person past-tense sentences the running
|
||||
memory composes from. Each subtask phase that produces a lasting
|
||||
change should have a milestone; transient motions (move, retract)
|
||||
should NOT.
|
||||
- First person, past tense. Start with "I".
|
||||
- One sentence. Functional outcome only — no grasp / motion detail.
|
||||
- Good: "I picked up the blue cube.", "I placed the blue cube in
|
||||
the green box.", "I wiped the counter."
|
||||
- Bad: "The robot arm grasped the blue cube." (third person),
|
||||
"I carefully grasped the blue cube with the parallel gripper."
|
||||
(irrelevant detail), "I moved towards the blue cube." (transient
|
||||
motion — should be omitted, not memorialised).
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"subtasks": ["<verb phrase>", ...],
|
||||
"memory_milestones": ["I <past-tense sentence>.", ...]
|
||||
}}
|
||||
@@ -0,0 +1,36 @@
|
||||
You are updating the robot's compressed semantic memory at the boundary of
|
||||
a completed subtask.
|
||||
|
||||
Reference (verbatim from MEM, Torne 2026):
|
||||
"Remove or compress information in the language memory whenever
|
||||
appropriate. Keep ONLY the minimal set of relevant information for future
|
||||
task execution. Specific object attributes (colors, precise quantities of
|
||||
each item) get discarded when their details won't affect subsequent
|
||||
actions. Functional outcomes (where items went, how many) are preserved."
|
||||
|
||||
Episode task: "{episode_task}"
|
||||
Previous memory: {prior_memory}
|
||||
Just-completed subtask: "{completed_subtask}"
|
||||
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
|
||||
|
||||
{vocabulary_block}Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the
|
||||
robot has accomplished so far — the running story it would tell itself.
|
||||
|
||||
Authoring rules:
|
||||
- First person, past tense. Every sentence starts with "I": "I picked
|
||||
up...", "I opened...", "I moved to...".
|
||||
- One or two short sentences. Extend the previous memory with the
|
||||
just-completed subtask; do not rewrite it from scratch.
|
||||
- Keep WHAT happened (functional outcomes — where items went, how many),
|
||||
drop HOW (grasp details, motions).
|
||||
- Compress completed steps and drop object attributes (colors, exact
|
||||
counts) once they no longer affect the remaining subtasks.
|
||||
|
||||
Example (MEM, Torne 2026):
|
||||
Before: "I prepared the pot and got the potatoes, milk, and butter. I
|
||||
moved to the drawer."
|
||||
After: "I prepared the pot and got the ingredients. I opened the
|
||||
drawer with the masher."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "memory": "<one or two short first-person past-tense sentences>" }}
|
||||
@@ -0,0 +1,80 @@
|
||||
You are labeling a teleoperated robot demonstration.
|
||||
|
||||
The user originally asked: "{episode_task}"
|
||||
|
||||
You are shown the entire demonstration as a single video. Watch the
|
||||
whole clip, then segment it into a list of consecutive atomic subtasks
|
||||
the robot performs.
|
||||
|
||||
{vocabulary_block}Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts:
|
||||
|
||||
- Each subtask = one COMPOSITE atomic skill the low-level policy can
|
||||
execute end-to-end. A "skill" bundles its own approach motion with
|
||||
its terminal action — do NOT split the approach off as its own
|
||||
subtask. The whole-arm policy already learns to reach as part of
|
||||
every manipulation primitive.
|
||||
- Write each subtask as an IMPERATIVE COMMAND, starting with one of
|
||||
these verbs (extend only when none fits):
|
||||
pick up <obj> — approach + grasp + lift in one subtask
|
||||
put <obj> on/in <loc> — transport + release in one subtask
|
||||
place <obj> on/in <loc> — synonym of "put"; pick one and stay consistent
|
||||
push <obj> — contact + linear shove
|
||||
pull <obj> — contact + linear retract
|
||||
turn <knob/dial/handle> — rotary actuation
|
||||
press <button> — single-press contact
|
||||
open <drawer/door/lid> — full open motion
|
||||
close <drawer/door/lid> — full close motion
|
||||
pour <src> into <dst> — tilt + flow
|
||||
insert <obj> into <slot>— alignment + push-fit
|
||||
go to <loc> — ONLY when no grasp / actuation follows
|
||||
(e.g. a pure relocation between phases).
|
||||
If the next subtask grasps something at
|
||||
that location, drop "go to ..." and just
|
||||
write "pick up ..." instead.
|
||||
- Forbidden ultra-fine splits — the VLM is NOT allowed to emit these
|
||||
as standalone subtasks; fold them into the parent composite:
|
||||
"move to X" → fold into "pick up X" (or whatever follows)
|
||||
"reach for X" → fold into "pick up X"
|
||||
"grasp X" → fold into "pick up X"
|
||||
"lift X" → fold into "pick up X" (or "put X on Y" if it's
|
||||
the transport phase of a place)
|
||||
"release X" → fold into "put X on Y" (or "place X in Y")
|
||||
- Keep it SHORT — a verb phrase, not a sentence. Drop articles
|
||||
("the", "a") and adverbs ("carefully", "slowly"). Add a "how"
|
||||
detail (which hand, which grasp point) ONLY when it is needed to
|
||||
disambiguate. Every subtask must begin with one of the verbs
|
||||
above (no leading nouns, no "then", no "first").
|
||||
- NEVER use third person. Never write "the robot", "the arm", "the
|
||||
gripper moves", "it picks up" — the robot is implied. Command it,
|
||||
do not describe it.
|
||||
- Use the exact object nouns from the task above. If the task says
|
||||
"cube", every subtask says "cube" — never switch to "block". If it
|
||||
says "box", never switch to "bin"/"container". Keep vocabulary
|
||||
consistent across the whole episode.
|
||||
- Good: "pick up blue cube", "put blue cube in box", "open drawer",
|
||||
"turn red knob", "press start button", "go to sink".
|
||||
- Bad: "move to blue cube" (approach as its own subtask — forbidden,
|
||||
must be folded into "pick up blue cube"); "the robot arm moves
|
||||
towards the blue cube" (third person, too long); "carefully pick
|
||||
up the cube" (adverb, article); "release the yellow block"
|
||||
("block" when the task said "cube", and "release" must be folded
|
||||
into a "put"/"place" subtask).
|
||||
- Subtasks are non-overlapping and cover the full episode in order.
|
||||
Choose the cut points yourself based on what you see in the video
|
||||
(gripper open/close events, contact, regrasps, transitions).
|
||||
- Each subtask spans at least {min_subtask_seconds} seconds. If a
|
||||
candidate span would be shorter, merge it into its neighbour
|
||||
rather than emitting it.
|
||||
- Do not exceed {max_steps} subtasks total. Fewer, larger composites
|
||||
are preferred over many micro-steps.
|
||||
- Every subtask's [start_time, end_time] must lie within
|
||||
[0.0, {episode_duration}] seconds.
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"subtasks": [
|
||||
{{"text": "<short imperative verb phrase>", "start": <float>, "end": <float>}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -0,0 +1,32 @@
|
||||
You are generating training data for a Hi Robot-style policy. We need
|
||||
{n} alternative phrasings of the same robot task so the policy sees
|
||||
diverse user prompts during training instead of the same canonical
|
||||
string repeated every frame.
|
||||
|
||||
Original task:
|
||||
"{base_task}"
|
||||
|
||||
Generate exactly {n} alternative phrasings of the same task. Vary:
|
||||
|
||||
- formality (casual / polite / curt)
|
||||
- verbosity (mostly short imperative; occasional polite request)
|
||||
- word choice (synonyms, different verbs)
|
||||
- sentence structure (imperative / question / suggestion)
|
||||
|
||||
Hard rules:
|
||||
- Each phrasing MUST preserve the exact meaning of the original task.
|
||||
Do not change which object is involved, the destination, or the
|
||||
action. Do not add extra steps. Do not invent new objects.
|
||||
- Each phrasing must be a short phrase or sentence, plain prose, no
|
||||
markdown, no quotes, no list numbers.
|
||||
- Phrasings must be distinct — no near-duplicates.
|
||||
- Output exactly {n} entries.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"rephrasings": [
|
||||
"<phrasing 1>",
|
||||
"<phrasing 2>",
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -0,0 +1,17 @@
|
||||
The video above shows a robot manipulation episode in full. Look at
|
||||
the entire video and describe in ONE concise sentence what the robot
|
||||
is doing.
|
||||
|
||||
Rules:
|
||||
- One sentence, in natural English, like a user instruction.
|
||||
- Capture the goal of the demonstration, not low-level motions.
|
||||
Example: "place the yellow cube into the red bin" — not "move the
|
||||
end-effector down 5cm and close the gripper".
|
||||
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
|
||||
- Do not invent objects or actions that aren't visible.
|
||||
- Do not output anything other than the JSON object below.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"task": "<single concise sentence describing what the robot does in this video>"
|
||||
}}
|
||||
@@ -0,0 +1,12 @@
|
||||
The user just asked the robot: "{episode_task}".
|
||||
|
||||
Generate a short verbal acknowledgement the robot would speak back before
|
||||
beginning the task. Style: compact, confident, friendly.
|
||||
|
||||
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
|
||||
"OK, starting with the sponge.", "Got it.".
|
||||
|
||||
Prefer very short replies: "Got it.", "On it.", "OK."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "text": "<the spoken acknowledgement>" }}
|
||||
@@ -0,0 +1,46 @@
|
||||
You are generating training data for a Hi Robot-style hierarchical
|
||||
robot policy. The robot in this demonstration has ALREADY executed
|
||||
every step shown in the video — we cannot retroactively change the
|
||||
action stream. To keep training data consistent with the video, the
|
||||
"interjection" must align with what the robot is *about to do next* in
|
||||
the demonstration, framed as a natural mid-task user request.
|
||||
|
||||
The episode's overall task: "{episode_task}".
|
||||
|
||||
The images above show roughly {window_seconds:.1f} seconds straddling a
|
||||
subtask boundary in the demonstration:
|
||||
|
||||
- Subtask the robot just finished: "{prev_subtask}"
|
||||
- Subtask the robot is about to start: "{next_subtask}"
|
||||
- Time into episode: {timestamp:.2f}s
|
||||
|
||||
Write ONE compact interjection the user would naturally say at this
|
||||
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
|
||||
Keep it like a mid-task coaching cue, not a full instruction paragraph.
|
||||
Also write the robot's compact verbal acknowledgement.
|
||||
|
||||
Hard rules:
|
||||
|
||||
- The interjection MUST be consistent with the next subtask. The user
|
||||
cannot ask for something different from what the robot then does in
|
||||
the video. If you're tempted to say "actually skip X" or "do Y
|
||||
instead", DO NOT — those would contradict the demonstration.
|
||||
- The interjection must reference an object, location, or action that
|
||||
is plausible given the visible scene and the next subtask text.
|
||||
- One short phrase or sentence each. Conversational, not robotic.
|
||||
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
|
||||
- Keep robot speech very short: "OK.", "On it.", "Doing that."
|
||||
|
||||
Style examples (vary the phrasing — don't reuse these verbatim):
|
||||
- "Now go ahead and {next_subtask}."
|
||||
- "Great, can you {next_subtask} next?"
|
||||
- "{next_subtask}, please."
|
||||
- "Before you continue, please {next_subtask}."
|
||||
- "Looking good — {next_subtask} now."
|
||||
- "Okay, {next_subtask}."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"interjection": "<short cue from the user, asking for the next subtask>",
|
||||
"speech": "<short robot acknowledgement>"
|
||||
}}
|
||||
@@ -0,0 +1,32 @@
|
||||
You are generating a frame-grounded visual question/answer pair for
|
||||
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
|
||||
Policies — both train policies on grounded features such as bounding box
|
||||
pixel coordinates, keypoints, counts, attributes, and spatial relations.
|
||||
|
||||
The frame shows a robot working on: "{episode_task}".
|
||||
|
||||
Question types and the EXACT answer JSON shape required for each:
|
||||
|
||||
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
|
||||
"bbox": [x1, y1, x2, y2]}}, ...]}}
|
||||
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
|
||||
ECoT example: "a white cup [124, 25, 176, 113]".
|
||||
|
||||
keypoint => {{"label": "<point>", "point_format": "xy",
|
||||
"point": [x, y]}}
|
||||
|
||||
count => {{"label": "<obj>", "count": <int>,
|
||||
"note": "<optional short note>"}}
|
||||
|
||||
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
|
||||
"value": "<observed value>"}}
|
||||
|
||||
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
|
||||
"above|below|near>", "object": "<obj>"}}
|
||||
|
||||
Generate a question of type "{question_type}". Output strictly valid JSON:
|
||||
|
||||
{{
|
||||
"question": "<short, frame-grounded question>",
|
||||
"answer": <object whose shape matches the schema above>
|
||||
}}
|
||||
274
src/lerobot/annotations/steerable_pipeline/reader.py
Normal file
274
src/lerobot/annotations/steerable_pipeline/reader.py
Normal file
@@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Datatrove-shaped reader.
|
||||
|
||||
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
|
||||
episode containing:
|
||||
|
||||
- ``episode_index``: int
|
||||
- ``frame_timestamps``: tuple[float, ...]
|
||||
- ``frame_indices``: tuple[int, ...]
|
||||
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
|
||||
- ``data_path``: pathlib.Path of the source parquet shard
|
||||
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
|
||||
|
||||
This shape lets each module operate per-episode without loading all parquet
|
||||
rows into memory at once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.io_utils import load_tasks
|
||||
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeRecord:
|
||||
"""Per-episode record yielded by the reader."""
|
||||
|
||||
episode_index: int
|
||||
episode_task: str
|
||||
frame_timestamps: tuple[float, ...]
|
||||
frame_indices: tuple[int, ...]
|
||||
data_path: Path
|
||||
row_offset: int # row offset within the parquet file where this episode starts
|
||||
row_count: int # number of rows for this episode
|
||||
|
||||
# Memoized parquet slice — populated on first ``frames_df()`` call so
|
||||
# repeat queries from different modules don't re-read the whole shard.
|
||||
_frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)
|
||||
|
||||
def frames_df(self): # type: ignore[no-untyped-def]
|
||||
"""Lazy-load the pandas slice for this episode (memoized)."""
|
||||
if self._frames_df_cache is None:
|
||||
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
|
||||
|
||||
table = pq.read_table(self.data_path)
|
||||
df: pd.DataFrame = table.to_pandas()
|
||||
self._frames_df_cache = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(
|
||||
drop=True
|
||||
)
|
||||
return self._frames_df_cache
|
||||
|
||||
|
||||
def reconstruct_subtask_spans(
|
||||
rows: Sequence[dict[str, Any]],
|
||||
*,
|
||||
episode_end_t: float | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
|
||||
|
||||
Each span's ``end`` is the next span's ``start``. The final span's
|
||||
``end`` defaults to its own ``start`` (zero-duration) — pass
|
||||
``episode_end_t`` to extend it to the episode's last frame instead,
|
||||
which is what downstream consumers (memory, interjection boundary
|
||||
selection) expect.
|
||||
|
||||
Used by the ``plan`` module (plan-update pass) and the
|
||||
``interjections`` module (interjection anchoring), which both need the
|
||||
same span shape.
|
||||
"""
|
||||
sorted_rows = sorted(
|
||||
(r for r in rows if r.get("style") == "subtask"),
|
||||
key=lambda r: float(r["timestamp"]),
|
||||
)
|
||||
spans: list[dict[str, Any]] = []
|
||||
for r in sorted_rows:
|
||||
t = float(r["timestamp"])
|
||||
if spans:
|
||||
spans[-1]["end"] = t
|
||||
spans.append({"text": r.get("content") or "", "start": t, "end": t})
|
||||
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
|
||||
spans[-1]["end"] = float(episode_end_t)
|
||||
return spans
|
||||
|
||||
|
||||
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||
"""Snap an arbitrary float to the nearest exact source frame timestamp.
|
||||
|
||||
Modules use this when emitting event-style rows so the row's
|
||||
timestamp matches a real parquet frame: event rows must land on an
|
||||
exact frame, otherwise the per-frame event lookup the writer does
|
||||
would never match them.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return float(t)
|
||||
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
|
||||
return float(nearest)
|
||||
|
||||
|
||||
def _load_tasks_lookup(root: Path) -> dict[int, str]:
|
||||
"""Map ``task_index -> task`` from ``meta/tasks.parquet``.
|
||||
|
||||
Returns an empty dict when the file is absent — the task description is
|
||||
derived later from the video if needed. Reuses the library-level
|
||||
:func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
|
||||
frame indexed by task string with a ``task_index`` column.
|
||||
"""
|
||||
if not (root / DEFAULT_TASKS_PATH).exists():
|
||||
return {}
|
||||
tasks = load_tasks(root)
|
||||
return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}
|
||||
|
||||
|
||||
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
|
||||
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
|
||||
|
||||
Episodes are yielded in ascending ``episode_index`` order. The reader does
|
||||
not assume a specific chunk/file layout: it scans every ``*.parquet``
|
||||
under ``data/`` and groups by ``episode_index``.
|
||||
"""
|
||||
tasks = _load_tasks_lookup(root)
|
||||
data_dir = root / "data"
|
||||
parquet_files = sorted(data_dir.rglob("*.parquet"))
|
||||
|
||||
only_set = set(only_episodes) if only_episodes is not None else None
|
||||
|
||||
for path in parquet_files:
|
||||
yield from _iter_one_path(path, tasks, only_set)
|
||||
|
||||
|
||||
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
|
||||
table = pq.read_table(path)
|
||||
names = table.column_names
|
||||
if "episode_index" not in names:
|
||||
return
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
timestamp_col = (
|
||||
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
|
||||
)
|
||||
frame_col = (
|
||||
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
|
||||
)
|
||||
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
|
||||
|
||||
def _build(
|
||||
ep: int,
|
||||
start: int,
|
||||
end: int,
|
||||
task_idx: int | None,
|
||||
ts_buf: list[float],
|
||||
fi_buf: list[int],
|
||||
) -> EpisodeRecord | None:
|
||||
if only_set is not None and ep not in only_set:
|
||||
return None
|
||||
task = tasks.get(task_idx, "") if task_idx is not None else ""
|
||||
return EpisodeRecord(
|
||||
episode_index=ep,
|
||||
episode_task=task,
|
||||
frame_timestamps=tuple(ts_buf),
|
||||
frame_indices=tuple(fi_buf),
|
||||
data_path=path,
|
||||
row_offset=start,
|
||||
row_count=end - start,
|
||||
)
|
||||
|
||||
cur_ep: int | None = None
|
||||
start_offset = 0
|
||||
ts_buf: list[float] = []
|
||||
fi_buf: list[int] = []
|
||||
cur_task_idx: int | None = None
|
||||
|
||||
for i, ep in enumerate(episode_col):
|
||||
if cur_ep is None:
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
continue
|
||||
if ep != cur_ep:
|
||||
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
else:
|
||||
ts_buf.append(timestamp_col[i])
|
||||
fi_buf.append(frame_col[i])
|
||||
|
||||
if cur_ep is not None:
|
||||
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
|
||||
|
||||
def gather_data_paths(root: Path) -> list[Path]:
|
||||
"""Return every ``data/chunk-*/file-*.parquet`` path under ``root``."""
|
||||
return sorted((root / "data").rglob("*.parquet"))
|
||||
|
||||
|
||||
def episode_offsets_per_path(path: Path) -> dict[int, tuple[int, int]]:
|
||||
"""Return ``{episode_index: (row_offset, row_count)}`` for one parquet."""
|
||||
table = pq.read_table(path, columns=["episode_index"])
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
out: dict[int, tuple[int, int]] = {}
|
||||
cur_ep: int | None = None
|
||||
start = 0
|
||||
for i, ep in enumerate(episode_col):
|
||||
if cur_ep is None:
|
||||
cur_ep = ep
|
||||
start = i
|
||||
continue
|
||||
if ep != cur_ep:
|
||||
out[cur_ep] = (start, i - start)
|
||||
cur_ep = ep
|
||||
start = i
|
||||
if cur_ep is not None:
|
||||
out[cur_ep] = (start, len(episode_col) - start)
|
||||
return out
|
||||
|
||||
|
||||
def keyframe_indices(record: EpisodeRecord, k: int) -> list[int]:
|
||||
"""Return ``k`` evenly spaced row indices into the episode (relative)."""
|
||||
n = record.row_count
|
||||
if k <= 0 or n == 0:
|
||||
return []
|
||||
if k >= n:
|
||||
return list(range(n))
|
||||
step = (n - 1) / (k - 1) if k > 1 else 0.0
|
||||
return [int(round(i * step)) for i in range(k)] if k > 1 else [n // 2]
|
||||
|
||||
|
||||
def lookup_data_path(root: Path, episode_index: int) -> tuple[Path, int, int] | None:
|
||||
"""Find the parquet file containing ``episode_index`` and its slice bounds."""
|
||||
for path in gather_data_paths(root):
|
||||
offsets = episode_offsets_per_path(path)
|
||||
if episode_index in offsets:
|
||||
start, count = offsets[episode_index]
|
||||
return path, start, count
|
||||
return None
|
||||
|
||||
|
||||
def episode_frame_timestamps(root: Path, episode_index: int) -> tuple[Any, list[float]]:
|
||||
"""Return the parquet path and per-frame timestamps for ``episode_index``."""
|
||||
found = lookup_data_path(root, episode_index)
|
||||
if found is None:
|
||||
raise ValueError(f"Episode {episode_index} not found under {root}/data/")
|
||||
path, start, count = found
|
||||
table = pq.read_table(path, columns=["timestamp"])
|
||||
timestamps = table.column("timestamp").to_pylist()[start : start + count]
|
||||
return path, [float(t) for t in timestamps]
|
||||
104
src/lerobot/annotations/steerable_pipeline/staging.py
Normal file
104
src/lerobot/annotations/steerable_pipeline/staging.py
Normal file
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Per-episode staging.
|
||||
|
||||
Each module writes its raw output as a JSONL file under
|
||||
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
|
||||
staging tree and partitions rows into the two language columns.
|
||||
|
||||
JSONL is preferred over parquet here because the staging artifact is meant to
|
||||
be human-inspectable, easy to diff between prompt iterations, and trivially
|
||||
appended to. The final dataset format is parquet; staging is just an
|
||||
intermediate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
ModuleName = str
|
||||
|
||||
_MODULES: tuple[ModuleName, ...] = (
|
||||
"plan",
|
||||
"interjections",
|
||||
"vqa",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeStaging:
|
||||
"""Filesystem layout for a single episode's staged module outputs."""
|
||||
|
||||
root: Path
|
||||
episode_index: int
|
||||
|
||||
@property
|
||||
def episode_dir(self) -> Path:
|
||||
return self.root / f"episode_{self.episode_index:06d}"
|
||||
|
||||
def path_for(self, module: ModuleName) -> Path:
|
||||
if module not in _MODULES:
|
||||
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
|
||||
return self.episode_dir / f"{module}.jsonl"
|
||||
|
||||
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
|
||||
path = self.path_for(module)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Atomic replace: a crash mid-write would otherwise leave a
|
||||
# half-written JSONL file that ``read()`` would then fail to
|
||||
# parse. Write to a sibling .tmp and rename so the target path
|
||||
# only ever points at a complete file.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
with tmp_path.open("w", encoding="utf-8") as f:
|
||||
for row in rows:
|
||||
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
|
||||
f.write("\n")
|
||||
tmp_path.replace(path)
|
||||
return path
|
||||
|
||||
def read(self, module: ModuleName) -> list[dict[str, Any]]:
|
||||
path = self.path_for(module)
|
||||
if not path.exists():
|
||||
return []
|
||||
out: list[dict[str, Any]] = []
|
||||
with path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
out.append(json.loads(line))
|
||||
return out
|
||||
|
||||
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
|
||||
return {m: self.read(m) for m in _MODULES}
|
||||
|
||||
def has(self, module: ModuleName) -> bool:
|
||||
return self.path_for(module).exists()
|
||||
|
||||
|
||||
def iter_staged_episodes(root: Path) -> Iterator[int]:
|
||||
"""Yield episode indices for which any staging artifact exists."""
|
||||
if not root.exists():
|
||||
return
|
||||
for child in sorted(root.iterdir()):
|
||||
if child.is_dir() and child.name.startswith("episode_"):
|
||||
try:
|
||||
yield int(child.name.removeprefix("episode_"))
|
||||
except ValueError:
|
||||
continue
|
||||
334
src/lerobot/annotations/steerable_pipeline/validator.py
Normal file
334
src/lerobot/annotations/steerable_pipeline/validator.py
Normal file
@@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pre-write validation against staged outputs.
|
||||
|
||||
Runs after all three modules have written their per-episode artifacts but
|
||||
*before* the writer rewrites parquet shards. The validator never touches
|
||||
parquet; it only inspects the staging tree and the source frame timestamps
|
||||
exposed by :class:`EpisodeRecord`.
|
||||
|
||||
Checks (per the plan's "Intermediate staging and validation" section):
|
||||
|
||||
- exact timestamp alignment against source frame timestamps
|
||||
- no orphan speech / interjection pairs
|
||||
- plan / memory emission consistency (events have a paired persistent row)
|
||||
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
|
||||
attribute / spatial)
|
||||
- every row maps to its correct column under :func:`column_for_style`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
column_for_style,
|
||||
is_view_dependent_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationReport:
|
||||
"""Outcome of one validation pass across all episodes."""
|
||||
|
||||
errors: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
episodes_checked: int = 0
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.errors
|
||||
|
||||
def add_error(self, message: str) -> None:
|
||||
self.errors.append(message)
|
||||
|
||||
def add_warning(self, message: str) -> None:
|
||||
self.warnings.append(message)
|
||||
|
||||
def summary(self) -> str:
|
||||
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
|
||||
|
||||
|
||||
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
|
||||
"bbox": {"detections"},
|
||||
"keypoint": {"label", "point_format", "point"},
|
||||
"count": {"label", "count"},
|
||||
"attribute": {"label", "attribute", "value"},
|
||||
"spatial": {"subject", "relation", "object"},
|
||||
}
|
||||
|
||||
|
||||
def classify_vqa_answer(payload: Any) -> str | None:
|
||||
"""Best-effort classification of a VQA answer payload to a question type."""
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
keys = set(payload.keys())
|
||||
for kind, required in VQA_ANSWER_SHAPES.items():
|
||||
if required.issubset(keys):
|
||||
return kind
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StagingValidator:
|
||||
"""Walks the staging tree and produces a :class:`ValidationReport`."""
|
||||
|
||||
timestamp_atol: float = 0.0 # exact-match by default
|
||||
dataset_camera_keys: tuple[str, ...] | None = None
|
||||
"""Known ``observation.images.*`` keys on the dataset. When set, the
|
||||
validator additionally enforces that every view-dependent row's
|
||||
``camera`` field references one of these keys. Pass ``None`` (default)
|
||||
to skip that cross-check (e.g. in unit tests with no real dataset)."""
|
||||
|
||||
def validate(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
) -> ValidationReport:
|
||||
report = ValidationReport()
|
||||
for record in records:
|
||||
self._validate_episode(record, staging_dir, report)
|
||||
report.episodes_checked += 1
|
||||
return report
|
||||
|
||||
def _validate_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging_dir: Path,
|
||||
report: ValidationReport,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged = staging.read_all()
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
for module_name, rows in staged.items():
|
||||
for row in rows:
|
||||
row = {**row, "_module": module_name}
|
||||
all_rows.append(row)
|
||||
|
||||
frame_ts = set(record.frame_timestamps)
|
||||
|
||||
events: list[dict[str, Any]] = []
|
||||
persistent: list[dict[str, Any]] = []
|
||||
for row in all_rows:
|
||||
self._check_column_routing(row, report, record.episode_index)
|
||||
self._check_camera_field(
|
||||
row, report, record.episode_index, self.dataset_camera_keys
|
||||
)
|
||||
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
|
||||
for row in events:
|
||||
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
|
||||
|
||||
self._check_speech_interjection_pairs(events, report, record.episode_index)
|
||||
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
|
||||
self._check_vqa_json(events, report, record.episode_index)
|
||||
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
|
||||
|
||||
def _check_camera_field(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
dataset_camera_keys: Sequence[str] | None,
|
||||
) -> None:
|
||||
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
|
||||
style = row.get("style")
|
||||
camera = row.get("camera")
|
||||
try:
|
||||
validate_camera_field(style, camera)
|
||||
except ValueError as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: {exc}"
|
||||
)
|
||||
return
|
||||
if (
|
||||
is_view_dependent_style(style)
|
||||
and dataset_camera_keys
|
||||
and camera not in dataset_camera_keys
|
||||
):
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
|
||||
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
|
||||
)
|
||||
|
||||
def _check_vqa_uniqueness_per_frame_camera(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
|
||||
counts: dict[tuple[float, str, str], int] = {}
|
||||
for row in events:
|
||||
if row.get("style") != "vqa":
|
||||
continue
|
||||
ts = row.get("timestamp")
|
||||
camera = row.get("camera")
|
||||
role = row.get("role")
|
||||
if ts is None or camera is None or role is None:
|
||||
continue # other validators flag these
|
||||
key = (float(ts), str(camera), str(role))
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
for (ts, camera, role), n in counts.items():
|
||||
if n > 1:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
|
||||
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
|
||||
)
|
||||
|
||||
def _check_column_routing(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
style = row.get("style")
|
||||
module = row.get("_module")
|
||||
try:
|
||||
target_col = column_for_style(style)
|
||||
except ValueError:
|
||||
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
|
||||
return
|
||||
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
|
||||
)
|
||||
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
|
||||
)
|
||||
|
||||
def _check_event_timestamp_alignment(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
frame_ts: set[float],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
|
||||
return
|
||||
if self.timestamp_atol == 0.0:
|
||||
if float(ts) not in frame_ts:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
|
||||
)
|
||||
else:
|
||||
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
|
||||
)
|
||||
|
||||
def _check_speech_interjection_pairs(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
speech_ts: dict[float, int] = {}
|
||||
interjection_ts: dict[float, int] = {}
|
||||
for row in events:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
continue
|
||||
ts_f = float(ts)
|
||||
if row.get("style") is None and row.get("role") == "assistant":
|
||||
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
|
||||
if row.get("style") == "interjection":
|
||||
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
|
||||
|
||||
for ts in interjection_ts:
|
||||
if ts not in speech_ts:
|
||||
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
|
||||
|
||||
def _check_plan_memory_consistency(
|
||||
self,
|
||||
persistent: Sequence[dict[str, Any]],
|
||||
events: Sequence[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
|
||||
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
|
||||
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
|
||||
interjection_ts = sorted(
|
||||
{
|
||||
float(r["timestamp"])
|
||||
for r in events
|
||||
if r.get("style") == "interjection" and r.get("timestamp") is not None
|
||||
}
|
||||
)
|
||||
|
||||
if persistent and not plan_ts:
|
||||
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
|
||||
# every interjection should have a same-timestamp plan refresh
|
||||
for ts in interjection_ts:
|
||||
if ts not in set(plan_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
|
||||
)
|
||||
# memory should be emitted at subtask boundaries (subset relation)
|
||||
if memory_ts and subtask_ts:
|
||||
mem_set = set(memory_ts)
|
||||
sub_set = set(subtask_ts)
|
||||
stray = sorted(mem_set - sub_set)
|
||||
if stray:
|
||||
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
|
||||
|
||||
def _check_vqa_json(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
for row in events:
|
||||
if row.get("style") != "vqa" or row.get("role") != "assistant":
|
||||
continue
|
||||
content = row.get("content")
|
||||
if content is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
|
||||
)
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(content)
|
||||
except (TypeError, ValueError) as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
|
||||
)
|
||||
continue
|
||||
shape = classify_vqa_answer(payload)
|
||||
if shape is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
|
||||
)
|
||||
703
src/lerobot/annotations/steerable_pipeline/vlm_client.py
Normal file
703
src/lerobot/annotations/steerable_pipeline/vlm_client.py
Normal file
@@ -0,0 +1,703 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared Qwen-VL client.
|
||||
|
||||
The pipeline uses a single shared VLM across modules. vLLM is preferred when
|
||||
available (high throughput, JSON-guided decoding); transformers is the
|
||||
fallback. A ``stub`` backend is used for unit tests so fixtures never call
|
||||
into a real model.
|
||||
|
||||
The client speaks one method, :meth:`VlmClient.generate_json`, which:
|
||||
|
||||
- accepts a list of OpenAI/HF-style multimodal messages,
|
||||
- requests JSON output (``json_mode=True`` enables guided decoding when the
|
||||
backend supports it),
|
||||
- batches requests transparently,
|
||||
- and reprompts once on a JSON parse failure with an inline correction
|
||||
message before raising.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import urllib.request
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .config import VlmConfig
|
||||
|
||||
|
||||
class VlmClient(Protocol):
|
||||
"""Protocol every backend must implement."""
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
"""Generate one JSON-decoded response per messages list."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubVlmClient:
|
||||
"""Deterministic stub used in unit tests.
|
||||
|
||||
A test passes a callable that maps the *last user message text* (or, if
|
||||
that is empty, the full message list) to a JSON-serializable response.
|
||||
"""
|
||||
|
||||
responder: Callable[[Sequence[dict[str, Any]]], Any]
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
return [self.responder(list(messages)) for messages in messages_batch]
|
||||
|
||||
|
||||
def _strip_to_json(text: str) -> Any:
|
||||
text = text.strip()
|
||||
# Strip <think>...</think> blocks (Qwen3 Thinking style)
|
||||
while "<think>" in text and "</think>" in text:
|
||||
start = text.find("<think>")
|
||||
end = text.find("</think>", start) + len("</think>")
|
||||
text = (text[:start] + text[end:]).strip()
|
||||
# Strip ```json ... ``` fences from chat-tuned backbones
|
||||
if text.startswith("```"):
|
||||
first = text.find("\n")
|
||||
last = text.rfind("```")
|
||||
if first != -1 and last != -1 and last > first:
|
||||
text = text[first + 1 : last].strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
# Fall back to extracting the first balanced {...} block.
|
||||
obj_text = _extract_first_json_object(text)
|
||||
if obj_text is None:
|
||||
raise json.JSONDecodeError("No JSON object found", text, 0)
|
||||
return json.loads(obj_text)
|
||||
|
||||
|
||||
def _extract_first_json_object(text: str) -> str | None:
|
||||
"""Return the first balanced ``{...}`` substring, ignoring braces in
|
||||
string literals. Returns ``None`` if no balanced block is found."""
|
||||
start = text.find("{")
|
||||
if start < 0:
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape = False
|
||||
for i in range(start, len(text)):
|
||||
ch = text[i]
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape = True
|
||||
continue
|
||||
# Note: ``escape`` is always False here — the ``if escape`` branch
|
||||
# above already handled and reset it.
|
||||
if ch == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start : i + 1]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GenericTextClient:
|
||||
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
|
||||
|
||||
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
|
||||
config: VlmConfig
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
|
||||
temp = temperature if temperature is not None else self.config.temperature
|
||||
raw = self.generate_text(messages_batch, max_tok, temp)
|
||||
out: list[Any] = []
|
||||
for messages, text in zip(messages_batch, raw, strict=True):
|
||||
try:
|
||||
out.append(_strip_to_json(text))
|
||||
continue
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
retry = list(messages) + [
|
||||
{"role": "assistant", "content": text},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Your previous reply was not valid JSON. "
|
||||
"Reply with strictly valid JSON, no prose, no fences."
|
||||
),
|
||||
},
|
||||
]
|
||||
retry_text = self.generate_text([retry], max_tok, temp)[0]
|
||||
try:
|
||||
out.append(_strip_to_json(retry_text))
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
# After retry: log preview and return None instead of crashing
|
||||
# the whole pipeline. Modules treat None as "skip".
|
||||
preview = retry_text.strip().replace("\n", " ")[:200]
|
||||
print(
|
||||
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
|
||||
flush=True,
|
||||
)
|
||||
out.append(None)
|
||||
return out
|
||||
|
||||
|
||||
def make_vlm_client(config: VlmConfig) -> VlmClient:
|
||||
"""Build the shared VLM client per the configured backend.
|
||||
|
||||
For ``stub``, callers should construct :class:`StubVlmClient` directly with
|
||||
a responder callable. ``stub`` here is rejected to make accidental misuse
|
||||
obvious.
|
||||
"""
|
||||
if config.backend == "stub":
|
||||
raise ValueError(
|
||||
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
|
||||
)
|
||||
if config.backend == "vllm":
|
||||
return _make_vllm_client(config)
|
||||
if config.backend == "transformers":
|
||||
return _make_transformers_client(config)
|
||||
if config.backend == "openai":
|
||||
return _make_openai_client(config)
|
||||
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
|
||||
|
||||
|
||||
def _make_vllm_client(config: VlmConfig) -> VlmClient:
|
||||
try:
|
||||
from vllm import LLM, SamplingParams # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
|
||||
) from exc
|
||||
# Workaround for cuDNN 9.x + torch 2.8 conv3d regression that surfaces
|
||||
# as CUDNN_STATUS_NOT_INITIALIZED in Qwen-VL vision-tower patch
|
||||
# embedders. Setting LEROBOT_DISABLE_CUDNN=1 forces native PyTorch
|
||||
# convolution kernels — slower but functional.
|
||||
if os.environ.get("LEROBOT_DISABLE_CUDNN", "").lower() in {"1", "true", "yes"}:
|
||||
import torch as _torch # noqa: PLC0415 - optional GPU dep, deferred
|
||||
|
||||
_torch.backends.cudnn.enabled = False
|
||||
llm_kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"tensor_parallel_size": config.tensor_parallel_size,
|
||||
"gpu_memory_utilization": config.gpu_memory_utilization,
|
||||
"trust_remote_code": config.trust_remote_code,
|
||||
}
|
||||
if config.max_model_len is not None:
|
||||
llm_kwargs["max_model_len"] = config.max_model_len
|
||||
llm = LLM(**llm_kwargs)
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
# ``guided_decoding`` would speed up parsing but its API differs across
|
||||
# vllm releases (dict vs GuidedDecodingParams). The _GenericTextClient
|
||||
# wrapper already has a one-retry JSON-recovery path, so we skip it.
|
||||
params = SamplingParams(max_tokens=max_tok, temperature=temp)
|
||||
# ``llm.chat`` handles chat-template application + multimodal input
|
||||
# extraction (image/video blocks) internally, which ``llm.generate``
|
||||
# does not.
|
||||
outputs = llm.chat([list(m) for m in batch], params)
|
||||
return [o.outputs[0].text for o in outputs]
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _make_transformers_client(config: VlmConfig) -> VlmClient:
|
||||
try:
|
||||
import torch # type: ignore[import-not-found]
|
||||
import transformers # type: ignore[import-not-found]
|
||||
from transformers import AutoProcessor # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError("transformers + torch are required for backend='transformers'.") from exc
|
||||
auto_cls = getattr(transformers, "AutoModelForImageTextToText", None) or getattr(
|
||||
transformers, "AutoModelForVision2Seq", None
|
||||
)
|
||||
if auto_cls is None:
|
||||
raise ImportError(
|
||||
"Neither AutoModelForImageTextToText nor AutoModelForVision2Seq is available in this "
|
||||
"transformers version. Install transformers>=4.45 (which has AutoModelForImageTextToText) "
|
||||
"for VL models."
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(config.model_id, trust_remote_code=config.trust_remote_code)
|
||||
use_accelerate = os.environ.get("LEROBOT_TRANSFORMERS_DEVICE_MAP", "manual") != "manual"
|
||||
# ``device_map='auto'`` triggers a known std::bad_alloc on the Qwen3-VL
|
||||
# post-load dispatch path (the alloc fails in accelerate's hook setup
|
||||
# even with TBs of host RAM). Default to manual: load on CPU with
|
||||
# ``low_cpu_mem_usage=True``, then ``.to("cuda")``. Set
|
||||
# ``LEROBOT_TRANSFORMERS_DEVICE_MAP=auto`` to opt back into the old path.
|
||||
if use_accelerate:
|
||||
model = auto_cls.from_pretrained(
|
||||
config.model_id,
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
import torch as _torch # noqa: PLC0415 - optional GPU dep, deferred
|
||||
|
||||
model = auto_cls.from_pretrained(
|
||||
config.model_id,
|
||||
torch_dtype=_torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
)
|
||||
model = model.to("cuda")
|
||||
model.eval()
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
outs: list[str] = []
|
||||
for messages in batch:
|
||||
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
inputs = processor(text=[text], return_tensors="pt").to(model.device)
|
||||
with torch.no_grad():
|
||||
gen = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_tok,
|
||||
temperature=temp,
|
||||
do_sample=temp > 0.0,
|
||||
)
|
||||
decoded = processor.batch_decode(
|
||||
gen[:, inputs["input_ids"].shape[-1] :], skip_special_tokens=True
|
||||
)[0]
|
||||
outs.append(decoded)
|
||||
return outs
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _make_openai_client(config: VlmConfig) -> VlmClient:
|
||||
"""Backend that talks to any OpenAI-compatible server.
|
||||
|
||||
Compatible with ``vllm serve``, ``transformers serve``,
|
||||
``ktransformers serve``, and hosted endpoints. By default the server
|
||||
is expected to be already running. Set ``auto_serve=True`` to have
|
||||
this client spawn one (default: ``transformers serve``), wait until
|
||||
it's ready, and tear it down on process exit.
|
||||
|
||||
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
|
||||
auto-converted to ``image_url`` data-URLs. Video blocks
|
||||
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
|
||||
multi-frame ``video_url`` items where supported.
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"openai package is required for backend='openai'. Install with `pip install openai`."
|
||||
) from exc
|
||||
|
||||
api_base = config.api_base
|
||||
api_key = config.api_key
|
||||
auto_serve = config.auto_serve
|
||||
api_bases: list[str] = [api_base]
|
||||
|
||||
print(
|
||||
f"[lerobot-annotate] backend=openai model={config.model_id} "
|
||||
f"api_base={api_base} auto_serve={auto_serve}",
|
||||
flush=True,
|
||||
)
|
||||
if auto_serve:
|
||||
if config.parallel_servers > 1:
|
||||
print(
|
||||
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
|
||||
flush=True,
|
||||
)
|
||||
api_bases = _spawn_parallel_inference_servers(config)
|
||||
elif _server_is_up(api_base):
|
||||
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
|
||||
else:
|
||||
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
|
||||
api_base = _spawn_inference_server(config)
|
||||
api_bases = [api_base]
|
||||
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
|
||||
|
||||
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
|
||||
# round-robin counter for parallel mode
|
||||
rr_counter = {"i": 0}
|
||||
|
||||
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
|
||||
# rejects it with HTTP 422. Send it only when explicitly opted in via
|
||||
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
|
||||
send_mm_kwargs = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "").lower() in {"1", "true", "yes"}
|
||||
|
||||
rr_lock = threading.Lock()
|
||||
|
||||
def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
|
||||
api_messages, mm_kwargs = _to_openai_messages(messages)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"messages": api_messages,
|
||||
"max_tokens": max_tok,
|
||||
"temperature": temp,
|
||||
}
|
||||
extra_body: dict[str, Any] = {}
|
||||
if send_mm_kwargs and mm_kwargs:
|
||||
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
|
||||
if config.chat_template_kwargs:
|
||||
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
|
||||
if extra_body:
|
||||
kwargs["extra_body"] = extra_body
|
||||
with rr_lock:
|
||||
chosen = clients[rr_counter["i"] % len(clients)]
|
||||
rr_counter["i"] += 1
|
||||
response = chosen.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||
# Parallel fan-out — vllm batches these on the server side.
|
||||
max_workers = min(config.client_concurrency, len(batch))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
|
||||
return [f.result() for f in futures]
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||
"""Spawn ``config.parallel_servers`` independent vllm replicas.
|
||||
|
||||
Each replica:
|
||||
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
|
||||
- listens on ``serve_port + i``
|
||||
- is shut down via the same atexit hook as the single-server path
|
||||
|
||||
Returns the list of ``api_base`` URLs the client should round-robin
|
||||
across.
|
||||
"""
|
||||
n = config.parallel_servers
|
||||
api_bases: list[str] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
ready_events: list[threading.Event] = []
|
||||
# Multiple readiness signals — uvicorn's own banner is suppressed at
|
||||
# ``--uvicorn-log-level warning``, so we also accept vllm's own
|
||||
# "Starting vLLM API server" line and the route-listing line. The
|
||||
# HTTP probe below is the ultimate fallback.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
# Single lock for all server-stream threads so multibyte chars from
|
||||
# different servers don't interleave and tear UTF-8 sequences.
|
||||
print_lock = threading.Lock()
|
||||
|
||||
base_cmd = config.serve_command or (
|
||||
f"vllm serve {shlex.quote(config.model_id)} "
|
||||
f"--tensor-parallel-size 1 "
|
||||
f"--max-model-len {config.max_model_len or 32768} "
|
||||
f"--uvicorn-log-level warning"
|
||||
)
|
||||
|
||||
num_gpus = config.num_gpus if config.num_gpus > 0 else n
|
||||
for i in range(n):
|
||||
port = config.serve_port + i
|
||||
gpu = i % num_gpus
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
|
||||
cmd = base_cmd.replace("{port}", str(port)) if "{port}" in base_cmd else f"{base_cmd} --port {port}"
|
||||
api_base = f"http://localhost:{port}/v1"
|
||||
api_bases.append(api_base)
|
||||
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
)
|
||||
procs.append(proc)
|
||||
ready = threading.Event()
|
||||
ready_events.append(ready)
|
||||
|
||||
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
|
||||
# Read whole lines and emit each line atomically under the
|
||||
# shared print_lock so output from N servers stays readable.
|
||||
assert p.stdout is not None
|
||||
for line in iter(p.stdout.readline, ""):
|
||||
with print_lock:
|
||||
sys.stdout.write(f"[server-{idx}] {line}")
|
||||
if not line.endswith(("\n", "\r")):
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
if any(m in line for m in ready_markers):
|
||||
ev.set()
|
||||
|
||||
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
|
||||
|
||||
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
|
||||
while not ev.is_set() and p.poll() is None:
|
||||
if _server_is_up(base):
|
||||
print(f"[server-{idx}] ready (http probe)", flush=True)
|
||||
ev.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is None:
|
||||
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
|
||||
p.send_signal(signal.SIGINT)
|
||||
for p in procs:
|
||||
try:
|
||||
p.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
p.kill()
|
||||
p.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
|
||||
)
|
||||
time.sleep(2)
|
||||
if any(not ev.is_set() for ev in ready_events):
|
||||
raise RuntimeError(f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s")
|
||||
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
|
||||
return api_bases
|
||||
|
||||
|
||||
def _server_is_up(api_base: str) -> bool:
|
||||
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||
url = api_base.rstrip("/") + "/models"
|
||||
# ``api_base`` is the user-configured local-server URL we just spawned
|
||||
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
|
||||
# is for arbitrary user-controlled URLs with file:/ schemes which
|
||||
# cannot reach this code path.
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
|
||||
return resp.status == 200
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
def _spawn_inference_server(config: VlmConfig) -> str:
|
||||
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
|
||||
accepts ``/v1/models``, and register a shutdown hook.
|
||||
|
||||
Streams the server's stdout/stderr to the parent terminal in
|
||||
real-time on a background thread so users can see model-load
|
||||
progress and errors as they happen.
|
||||
|
||||
Returns the full ``api_base`` URL the OpenAI client should use.
|
||||
"""
|
||||
cmd = config.serve_command
|
||||
if not cmd:
|
||||
cmd = (
|
||||
f"transformers serve {shlex.quote(config.model_id)} "
|
||||
f"--port {config.serve_port} --continuous-batching"
|
||||
)
|
||||
api_base = f"http://localhost:{config.serve_port}/v1"
|
||||
print(f"[server] launching: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Watch the server output for the uvicorn readiness banner. This is
|
||||
# more reliable than polling /v1/models because transformers serve
|
||||
# rescans its cache on every model-list request, which can exceed
|
||||
# the urllib timeout and trigger an infinite probe loop.
|
||||
ready_event = threading.Event()
|
||||
# See _spawn_parallel_inference_servers for why we accept these.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
|
||||
def _probe() -> None:
|
||||
while not ready_event.is_set() and proc.poll() is None:
|
||||
if _server_is_up(api_base):
|
||||
print("[server] ready (http probe)", flush=True)
|
||||
ready_event.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, daemon=True).start()
|
||||
|
||||
def _stream_output() -> None:
|
||||
# Read raw chunks instead of iterating lines so tqdm progress
|
||||
# bars (which overwrite using \r) flush in real time.
|
||||
assert proc.stdout is not None
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
while True:
|
||||
ch = proc.stdout.read(1)
|
||||
if ch == "":
|
||||
# process exited; flush any tail
|
||||
if buf:
|
||||
sys.stdout.write(buf)
|
||||
sys.stdout.flush()
|
||||
return
|
||||
if not prefix_started:
|
||||
sys.stdout.write("[server] ")
|
||||
prefix_started = True
|
||||
sys.stdout.write(ch)
|
||||
sys.stdout.flush()
|
||||
buf += ch
|
||||
if ch in ("\n", "\r"):
|
||||
if any(marker in buf for marker in ready_markers):
|
||||
ready_event.set()
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
|
||||
threading.Thread(target=_stream_output, daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
if proc.poll() is None:
|
||||
print(f"[server] stopping pid={proc.pid}", flush=True)
|
||||
proc.send_signal(signal.SIGINT)
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
if proc.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
|
||||
f"See [server] log lines above for the cause."
|
||||
)
|
||||
if ready_event.wait(timeout=2):
|
||||
return api_base
|
||||
proc.terminate()
|
||||
raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
|
||||
|
||||
|
||||
def _to_openai_messages(
|
||||
messages: Sequence[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
"""Convert internal messages to OpenAI chat format.
|
||||
|
||||
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
|
||||
(``fps`` from ``video_url`` blocks) are extracted out so the caller
|
||||
can pass them via ``extra_body.mm_processor_kwargs`` rather than
|
||||
inside the content blocks (which transformers serve rejects).
|
||||
|
||||
File-URL video blocks are inlined as base64 data URLs.
|
||||
"""
|
||||
out_messages: list[dict[str, Any]] = []
|
||||
mm_kwargs: dict[str, Any] = {}
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
out_messages.append({"role": message["role"], "content": content})
|
||||
continue
|
||||
out_blocks: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
if block_type == "text":
|
||||
out_blocks.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == "image":
|
||||
out_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
|
||||
)
|
||||
elif block_type == "video":
|
||||
frames = block.get("video", [])
|
||||
for img in frames:
|
||||
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
|
||||
elif block_type == "video_url":
|
||||
video_url = dict(block["video_url"])
|
||||
url = video_url.get("url", "")
|
||||
if url.startswith("file://"):
|
||||
video_url["url"] = _file_to_data_url(url[len("file://") :])
|
||||
out_blocks.append({"type": "video_url", "video_url": video_url})
|
||||
fps = block.get("fps")
|
||||
if fps is not None:
|
||||
mm_kwargs["fps"] = fps
|
||||
else:
|
||||
out_blocks.append(block)
|
||||
out_messages.append({"role": message["role"], "content": out_blocks})
|
||||
return out_messages, mm_kwargs
|
||||
|
||||
|
||||
def _file_to_data_url(path: str) -> str:
|
||||
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
|
||||
with open(path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:video/mp4;base64,{b64}"
|
||||
|
||||
|
||||
def _pil_to_data_url(image: Any) -> str:
|
||||
"""Encode a PIL.Image as a base64 data URL."""
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
return f"data:image/png;base64,{b64}"
|
||||
|
||||
|
||||
def _messages_to_prompt(messages: Sequence[dict[str, Any]]) -> Any:
|
||||
"""Pass-through hook used by the vllm backend.
|
||||
|
||||
vllm exposes its own multimodal entry points that vary by version; for the
|
||||
base flow we simply forward the raw message list and let the caller's
|
||||
custom backend handle templating. Real deployments override this.
|
||||
"""
|
||||
return list(messages)
|
||||
222
src/lerobot/annotations/steerable_pipeline/vocabulary.py
Normal file
222
src/lerobot/annotations/steerable_pipeline/vocabulary.py
Normal file
@@ -0,0 +1,222 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Dataset-level canonical vocabulary discovery (Phase 0).
|
||||
|
||||
The downstream consumer of these annotations is a low-level action expert
|
||||
conditioned on the ``subtask`` string. Free-form per-episode LLM rephrasing
|
||||
gives near-unique strings per occurrence, which collapses the action
|
||||
expert's conditioning to noise and makes runtime subtask-paraphrase drift
|
||||
catastrophic. The Hi-Robot / π0.6-MEM recipe ships a small canonical
|
||||
vocabulary per environment (~10 strings) that every episode reuses; this
|
||||
module derives that vocabulary automatically from the first few episode
|
||||
videos and persists it next to the dataset.
|
||||
|
||||
Pipeline-level flow:
|
||||
|
||||
Phase 0 (here): watch N sample episodes → produce vocabulary.json
|
||||
Phase 1 (plan module): reuse vocabulary on every episode, both as
|
||||
prompt-side constraint *and* post-VLM validation
|
||||
|
||||
The vocabulary is JSON, lives at ``<root>/meta/canonical_vocabulary.json``,
|
||||
and is human-inspectable / hand-editable — if the discovered set is wrong,
|
||||
operators edit the file and re-run the pipeline without phase 0.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import VocabularyConfig
|
||||
from .frames import FrameProvider, null_provider, to_video_block
|
||||
from .prompts import load as load_prompt
|
||||
from .reader import EpisodeRecord
|
||||
from .vlm_client import VlmClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VOCABULARY_FILENAME = "canonical_vocabulary.json"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Vocabulary:
|
||||
"""Canonical phrasings shared across every episode of one dataset.
|
||||
|
||||
Both lists are strict: per-episode subtask + memory generation pick
|
||||
from these strings only; the downstream policy then has a small,
|
||||
repeatable target distribution to learn instead of thousands of
|
||||
LLM paraphrases.
|
||||
"""
|
||||
|
||||
subtasks: tuple[str, ...]
|
||||
"""Imperative subtask labels — what the low-level policy is conditioned
|
||||
on. Verb-first, telegraphic, consistent object nouns. Example:
|
||||
``("move to blue cube", "grasp blue cube", "lift blue cube",
|
||||
"place blue cube in box", "retract arm")``.
|
||||
"""
|
||||
|
||||
memory_milestones: tuple[str, ...]
|
||||
"""First-person past-tense milestone sentences — building blocks for
|
||||
the running memory string. Example: ``("I picked up the blue cube.",
|
||||
"I placed the blue cube in the green box.")``. Each milestone maps
|
||||
1:1 onto a completed subtask phase; ``memory_at_step_k`` is the
|
||||
concatenation of milestones for completed phases.
|
||||
"""
|
||||
|
||||
def to_json(self) -> dict[str, list[str]]:
|
||||
return {
|
||||
"subtasks": list(self.subtasks),
|
||||
"memory_milestones": list(self.memory_milestones),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, payload: dict[str, Any]) -> Vocabulary:
|
||||
subtasks = tuple(
|
||||
str(s).strip() for s in (payload.get("subtasks") or []) if str(s).strip()
|
||||
)
|
||||
memory_milestones = tuple(
|
||||
str(s).strip() for s in (payload.get("memory_milestones") or []) if str(s).strip()
|
||||
)
|
||||
return cls(subtasks=subtasks, memory_milestones=memory_milestones)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return not self.subtasks and not self.memory_milestones
|
||||
|
||||
|
||||
def vocabulary_path(root: Path) -> Path:
|
||||
"""Return the canonical on-disk location for the vocabulary file."""
|
||||
return root / "meta" / VOCABULARY_FILENAME
|
||||
|
||||
|
||||
def load_vocabulary(root: Path) -> Vocabulary | None:
|
||||
"""Read ``<root>/meta/canonical_vocabulary.json`` if present.
|
||||
|
||||
Returns ``None`` when the file does not exist — callers fall back to
|
||||
free-form (unconstrained) subtask + memory generation, preserving the
|
||||
pipeline's behaviour on datasets that never ran phase 0.
|
||||
"""
|
||||
path = vocabulary_path(root)
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError) as exc:
|
||||
logger.warning("could not read %s: %s — proceeding without vocabulary", path, exc)
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
logger.warning("%s is not a JSON object — ignoring", path)
|
||||
return None
|
||||
vocab = Vocabulary.from_json(payload)
|
||||
if vocab.is_empty():
|
||||
return None
|
||||
return vocab
|
||||
|
||||
|
||||
def save_vocabulary(root: Path, vocab: Vocabulary) -> Path:
|
||||
"""Atomically persist ``vocab`` to ``<root>/meta/canonical_vocabulary.json``."""
|
||||
path = vocabulary_path(root)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(path.suffix + ".tmp")
|
||||
tmp.write_text(
|
||||
json.dumps(vocab.to_json(), indent=2, ensure_ascii=False) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
tmp.replace(path)
|
||||
return path
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocabularyDiscoveryModule:
|
||||
"""Derive a dataset-level canonical vocabulary from sample episodes.
|
||||
|
||||
Phase 0 of the executor: pulls ``config.sample_episodes`` episode
|
||||
videos, packs them into one Qwen-VL multi-video prompt, and asks the
|
||||
model to enumerate the small set of canonical subtask labels +
|
||||
memory milestones that recur across them. The output is persisted
|
||||
to ``meta/canonical_vocabulary.json`` and consumed by phase 1.
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: VocabularyConfig
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def discover(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
*,
|
||||
existing: Vocabulary | None = None,
|
||||
) -> Vocabulary | None:
|
||||
"""Run vocabulary discovery against the first N sample episodes.
|
||||
|
||||
``existing`` short-circuits the VLM call when ``config.reuse_existing``
|
||||
is True and an on-disk vocabulary is already present — keeps re-runs
|
||||
cheap and lets operators hand-edit the file without it getting
|
||||
overwritten.
|
||||
"""
|
||||
if existing is not None and self.config.reuse_existing:
|
||||
logger.info(
|
||||
"vocabulary: reusing existing (%d subtasks, %d memory milestones)",
|
||||
len(existing.subtasks),
|
||||
len(existing.memory_milestones),
|
||||
)
|
||||
return existing
|
||||
|
||||
sample = list(records[: max(1, int(self.config.sample_episodes))])
|
||||
if not sample:
|
||||
return None
|
||||
|
||||
task_hint = next((r.episode_task for r in sample if r.episode_task), "")
|
||||
prompt = load_prompt("module_0_vocabulary").format(
|
||||
episode_task=task_hint or "(unspecified)",
|
||||
n_episodes=len(sample),
|
||||
)
|
||||
# Pack one video block per sample episode so the VLM sees the
|
||||
# variation across episodes (different starting poses, different
|
||||
# object placements) rather than overfitting to one trajectory.
|
||||
content: list[dict[str, Any]] = []
|
||||
for record in sample:
|
||||
video_frames = self.frame_provider.video_for_episode(
|
||||
record, int(self.config.max_video_frames_per_episode)
|
||||
)
|
||||
if video_frames:
|
||||
content.extend(to_video_block(video_frames))
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages = [{"role": "user", "content": content}]
|
||||
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
logger.warning("vocabulary: VLM did not return a JSON object — skipping")
|
||||
return None
|
||||
|
||||
vocab = Vocabulary.from_json(result)
|
||||
if vocab.is_empty():
|
||||
logger.warning("vocabulary: VLM returned an empty vocabulary — skipping")
|
||||
return None
|
||||
logger.info(
|
||||
"vocabulary: discovered %d subtask labels + %d memory milestones from %d episodes",
|
||||
len(vocab.subtasks),
|
||||
len(vocab.memory_milestones),
|
||||
len(sample),
|
||||
)
|
||||
return vocab
|
||||
356
src/lerobot/annotations/steerable_pipeline/writer.py
Normal file
356
src/lerobot/annotations/steerable_pipeline/writer.py
Normal file
@@ -0,0 +1,356 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Final parquet rewrite.
|
||||
|
||||
For every episode the writer:
|
||||
|
||||
1. reads the staged module outputs,
|
||||
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
|
||||
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
|
||||
3. sorts each slice deterministically,
|
||||
4. broadcasts the persistent slice across every frame in the episode,
|
||||
5. for each frame, materializes the sublist of event rows whose timestamp
|
||||
exactly equals that frame's timestamp,
|
||||
6. drops the legacy ``subtask_index`` column,
|
||||
7. writes the parquet shard back in place.
|
||||
|
||||
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
|
||||
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
|
||||
struct for every speech atom. The tool *schema* (the description
|
||||
of the ``say`` function and its parameters) is a fixed code constant —
|
||||
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
|
||||
it directly rather than reading a redundant per-row column.
|
||||
|
||||
Invariants enforced here (and re-checked by the validator):
|
||||
|
||||
- per-episode persistent slice is byte-identical across every frame;
|
||||
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
|
||||
(timestamps come straight from the source parquet — never recomputed);
|
||||
- every row passes ``column_for_style(style)``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool schema constants live in lerobot.datasets.language — single
|
||||
# source of truth. Re-exported here so existing imports
|
||||
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
|
||||
# keep working.
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
|
||||
|
||||
|
||||
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
|
||||
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||
# events are bucketed per-frame, but within a frame we still want determinism
|
||||
return (
|
||||
row.get("style") or "",
|
||||
row.get("role") or "",
|
||||
row.get("camera") or "",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the persistent column's struct shape."""
|
||||
style = row.get("style")
|
||||
if style not in PERSISTENT_STYLES:
|
||||
raise ValueError(
|
||||
f"persistent slice contains row with non-persistent style {style!r}; "
|
||||
"row would be misrouted under column_for_style()"
|
||||
)
|
||||
if "timestamp" not in row:
|
||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||
if "role" not in row:
|
||||
# Surface a friendly error from the writer rather than letting
|
||||
# the raw KeyError bubble out of the dict access below — modules
|
||||
# are expected to always emit ``role``, but the validator
|
||||
# currently doesn't check this so a future bug would otherwise
|
||||
# be hard to triage.
|
||||
raise ValueError(f"persistent row missing role: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"timestamp": float(row["timestamp"]),
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
|
||||
style = row.get("style")
|
||||
if style is not None and style not in EVENT_ONLY_STYLES:
|
||||
raise ValueError(
|
||||
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
|
||||
)
|
||||
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||
if "role" not in row:
|
||||
raise ValueError(f"event row missing role: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_tool_calls(value: Any) -> list[Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
|
||||
return list(value)
|
||||
|
||||
|
||||
def _validate_atom_invariants(row: dict[str, Any]) -> None:
|
||||
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
|
||||
has_content = row.get("content") is not None
|
||||
has_tools = row.get("tool_calls") is not None
|
||||
if not (has_content or has_tools):
|
||||
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
|
||||
if row.get("style") is None and not has_tools:
|
||||
raise ValueError(f"style=None requires tool_calls: {row!r}")
|
||||
|
||||
|
||||
def _validate_speech_atom(row: dict[str, Any]) -> None:
|
||||
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
|
||||
if row.get("style") is not None:
|
||||
return # not a speech atom
|
||||
if row.get("role") != "assistant":
|
||||
raise ValueError(f"speech atom must have role=assistant: {row!r}")
|
||||
if row.get("content") is not None:
|
||||
raise ValueError(f"speech atom must have content=null: {row!r}")
|
||||
tool_calls = row.get("tool_calls")
|
||||
if not tool_calls or not isinstance(tool_calls, list):
|
||||
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
|
||||
first = tool_calls[0]
|
||||
if not isinstance(first, dict):
|
||||
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
|
||||
if first.get("type") != "function":
|
||||
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
|
||||
fn = first.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
|
||||
args = fn.get("arguments") or {}
|
||||
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
|
||||
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageColumnsWriter:
|
||||
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
|
||||
|
||||
drop_existing_subtask_index: bool = True
|
||||
|
||||
def write_all(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> list[Path]:
|
||||
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
|
||||
for record in records:
|
||||
episodes_by_path[record.data_path].append(record)
|
||||
|
||||
written: list[Path] = []
|
||||
for path, eps in episodes_by_path.items():
|
||||
self._rewrite_one(path, eps, staging_dir, root)
|
||||
written.append(path)
|
||||
return written
|
||||
|
||||
def _rewrite_one(
|
||||
self,
|
||||
path: Path,
|
||||
episodes: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> None:
|
||||
table = pq.read_table(path)
|
||||
n_rows = table.num_rows
|
||||
|
||||
# Ensure we cover every episode in the file. Episodes that don't have
|
||||
# staging artifacts are passed through with empty annotation lists —
|
||||
# this keeps the writer idempotent and safe for partial reruns.
|
||||
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
|
||||
for record in episodes:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged_per_ep[record.episode_index] = staging.read_all()
|
||||
|
||||
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
|
||||
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
|
||||
|
||||
for ep_index, ep_staged in staged_per_ep.items():
|
||||
persistent_rows: list[dict[str, Any]] = []
|
||||
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
|
||||
for _module_name, rows in ep_staged.items():
|
||||
for row in rows:
|
||||
style = row.get("style")
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
persistent_rows.append(row)
|
||||
else:
|
||||
event_rows.append(row)
|
||||
|
||||
persistent_rows.sort(key=_row_persistent_sort_key)
|
||||
normalized_persistent = []
|
||||
for r in persistent_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
normalized_persistent.append(_normalize_persistent_row(r))
|
||||
persistent_by_ep[ep_index] = normalized_persistent
|
||||
|
||||
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
|
||||
for r in event_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
ts = float(r["timestamp"])
|
||||
buckets[ts].append(_normalize_event_row(r))
|
||||
for ts in list(buckets.keys()):
|
||||
buckets[ts].sort(key=_row_event_sort_key)
|
||||
events_by_ep_ts[ep_index] = buckets
|
||||
|
||||
episode_col = (
|
||||
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
|
||||
)
|
||||
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
|
||||
if episode_col is None or ts_col is None:
|
||||
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
|
||||
|
||||
per_row_persistent: list[list[dict[str, Any]]] = []
|
||||
per_row_events: list[list[dict[str, Any]]] = []
|
||||
for i in range(n_rows):
|
||||
ep = episode_col[i]
|
||||
ts = float(ts_col[i])
|
||||
per_row_persistent.append(persistent_by_ep.get(ep, []))
|
||||
buckets = events_by_ep_ts.get(ep, {})
|
||||
per_row_events.append(buckets.get(ts, []))
|
||||
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||
# Windows when source and target sit on the same filesystem.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
pq.write_table(new_table, tmp_path)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _materialize_table(
|
||||
self,
|
||||
table: pa.Table,
|
||||
persistent: list[list[dict[str, Any]]],
|
||||
events: list[list[dict[str, Any]]],
|
||||
*,
|
||||
drop_old: bool,
|
||||
) -> pa.Table:
|
||||
cols = []
|
||||
names = []
|
||||
for name in table.column_names:
|
||||
if drop_old and name == "subtask_index":
|
||||
continue
|
||||
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
|
||||
continue # we'll re-add canonical versions
|
||||
# Strip any legacy ``tools`` column previously emitted by older
|
||||
# writers — the schema no longer uses it (constant lives in
|
||||
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
|
||||
if name == "tools":
|
||||
continue
|
||||
cols.append(table.column(name))
|
||||
names.append(name)
|
||||
|
||||
# We let pyarrow infer struct/list schema rather than passing the
|
||||
# canonical type from `lerobot.datasets.language` directly: that type
|
||||
# uses `pa.json_()` for the `tool_calls` element type, which
|
||||
# `pa.array(..., type=...)` cannot materialize from Python lists on
|
||||
# current pyarrow versions. The inferred schema round-trips through
|
||||
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
|
||||
# exercises the same flow.
|
||||
persistent_arr = pa.array(persistent)
|
||||
events_arr = pa.array(events)
|
||||
|
||||
cols.extend([persistent_arr, events_arr])
|
||||
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
|
||||
|
||||
return pa.Table.from_arrays(cols, names=names)
|
||||
|
||||
|
||||
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
|
||||
"""Build a canonical speech tool-call atom for the events column."""
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"style": None,
|
||||
"timestamp": float(timestamp),
|
||||
"camera": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"arguments": {"text": text},
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def normalize_rows_for_writer(
|
||||
rows: Iterable[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""Helper used by tests/validators to partition a flat row list into
|
||||
(persistent_rows, event_rows) using ``column_for_style``.
|
||||
"""
|
||||
persistent: list[dict[str, Any]] = []
|
||||
events: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
return persistent, events
|
||||
205
src/lerobot/scripts/lerobot_annotate.py
Normal file
205
src/lerobot/scripts/lerobot_annotate.py
Normal file
@@ -0,0 +1,205 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``lerobot-annotate`` — populate ``language_persistent`` and
|
||||
``language_events`` columns on a LeRobot dataset.
|
||||
|
||||
Annotations live directly in ``data/chunk-*/file-*.parquet``.
|
||||
|
||||
Example:
|
||||
|
||||
uv run lerobot-annotate \\
|
||||
--root=/path/to/dataset \\
|
||||
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
||||
|
||||
For distributed runs, see ``examples/annotations/run_hf_job.py``.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import make_vlm_client
|
||||
from lerobot.annotations.steerable_pipeline.vocabulary import VocabularyDiscoveryModule
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
from lerobot.configs import parser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_root(cfg: AnnotationPipelineConfig) -> Path:
|
||||
if cfg.root is not None:
|
||||
return Path(cfg.root)
|
||||
if cfg.repo_id is not None:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
return Path(snapshot_download(repo_id=cfg.repo_id, repo_type="dataset"))
|
||||
raise ValueError("Either --root or --repo_id must be provided.")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def annotate(cfg: AnnotationPipelineConfig) -> None:
|
||||
"""Run the steerable annotation pipeline against a dataset."""
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
root = _resolve_root(cfg)
|
||||
logger.info("annotate: root=%s", root)
|
||||
|
||||
vlm = make_vlm_client(cfg.vlm)
|
||||
frame_provider = make_frame_provider(
|
||||
root, camera_key=cfg.vlm.camera_key, video_backend=cfg.video_backend
|
||||
)
|
||||
# Surface the resolved cameras up front so a silent vqa-module no-op
|
||||
# is obvious in job output rather than discovered post-hoc by counting
|
||||
# parquet rows.
|
||||
cam_keys = list(getattr(frame_provider, "camera_keys", []) or [])
|
||||
logger.info(
|
||||
"annotate: frame_provider default camera=%r, all cameras=%s",
|
||||
getattr(frame_provider, "camera_key", None),
|
||||
cam_keys,
|
||||
)
|
||||
if cfg.vqa.enabled and not cam_keys:
|
||||
logger.warning(
|
||||
"annotate: the vqa module is enabled but no cameras were "
|
||||
"resolved — it will produce zero VQA rows. Check "
|
||||
"meta/info.json for observation.images.* features, or pass "
|
||||
"--vlm.camera_key=<key> to seed the cameras list."
|
||||
)
|
||||
plan = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan, frame_provider=frame_provider)
|
||||
interjections = InterjectionsAndSpeechModule(
|
||||
vlm=vlm, config=cfg.interjections, seed=cfg.seed, frame_provider=frame_provider
|
||||
)
|
||||
vqa = GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed, frame_provider=frame_provider)
|
||||
vocabulary = VocabularyDiscoveryModule(
|
||||
vlm=vlm, config=cfg.vocabulary, frame_provider=frame_provider
|
||||
)
|
||||
writer = LanguageColumnsWriter()
|
||||
validator = StagingValidator(
|
||||
dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None,
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
config=cfg,
|
||||
plan=plan,
|
||||
interjections=interjections,
|
||||
vqa=vqa,
|
||||
vocabulary=vocabulary,
|
||||
writer=writer,
|
||||
validator=validator,
|
||||
)
|
||||
summary = executor.run(root)
|
||||
logger.info("annotate: wrote %d shard(s)", len(summary.written_paths))
|
||||
for phase in summary.phases:
|
||||
logger.info(
|
||||
"annotate: phase=%s processed=%d skipped=%d",
|
||||
phase.name,
|
||||
phase.episodes_processed,
|
||||
phase.episodes_skipped,
|
||||
)
|
||||
if summary.validation_report.warnings:
|
||||
for w in summary.validation_report.warnings:
|
||||
logger.warning(w)
|
||||
|
||||
if cfg.push_to_hub:
|
||||
if cfg.repo_id is None and cfg.dest_repo_id is None:
|
||||
raise ValueError(
|
||||
"--push_to_hub requires --repo_id or --dest_repo_id (the dataset repo to push to)."
|
||||
)
|
||||
_push_to_hub(root, cfg)
|
||||
|
||||
|
||||
def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
|
||||
"""Upload the annotated dataset directory to the Hub.
|
||||
|
||||
Pushes to ``cfg.dest_repo_id`` when set, otherwise back to ``cfg.repo_id``.
|
||||
"""
|
||||
from huggingface_hub import HfApi # noqa: PLC0415
|
||||
|
||||
repo_id = cfg.dest_repo_id or cfg.repo_id
|
||||
commit_message = cfg.push_commit_message or "Add steerable annotations (lerobot-annotate)"
|
||||
api = HfApi()
|
||||
print(f"[lerobot-annotate] creating/locating dataset repo {repo_id}...", flush=True)
|
||||
api.create_repo(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
private=cfg.push_private,
|
||||
exist_ok=True,
|
||||
)
|
||||
print(f"[lerobot-annotate] uploading {root} -> {repo_id}...", flush=True)
|
||||
commit_info = api.upload_folder(
|
||||
folder_path=str(root),
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
commit_message=commit_message,
|
||||
ignore_patterns=[".annotate_staging/**", "**/.DS_Store"],
|
||||
)
|
||||
print(f"[lerobot-annotate] uploaded to https://huggingface.co/datasets/{repo_id}", flush=True)
|
||||
|
||||
# Tag the upload with the codebase version. ``LeRobotDatasetMetadata``
|
||||
# resolves the dataset revision via ``get_safe_version`` which scans
|
||||
# for tags like ``v3.0``; without a tag it raises
|
||||
# ``RevisionNotFoundError``. Read the version straight from the
|
||||
# dataset's own ``meta/info.json`` so we tag whatever the writer
|
||||
# actually wrote (no accidental drift if the codebase floor moves).
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
version_tag = CODEBASE_VERSION
|
||||
if info_path.exists():
|
||||
try:
|
||||
from lerobot.utils.io_utils import load_json # noqa: PLC0415
|
||||
|
||||
info = load_json(info_path)
|
||||
ds_version = info.get("codebase_version")
|
||||
if isinstance(ds_version, str) and ds_version.startswith("v"):
|
||||
version_tag = ds_version
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[lerobot-annotate] could not read codebase_version from info.json ({exc}); falling back to {version_tag}", flush=True)
|
||||
revision = getattr(commit_info, "oid", None)
|
||||
tag_kwargs = {
|
||||
"repo_id": repo_id,
|
||||
"tag": version_tag,
|
||||
"repo_type": "dataset",
|
||||
"exist_ok": True,
|
||||
}
|
||||
if revision is not None:
|
||||
tag_kwargs["revision"] = revision
|
||||
|
||||
try:
|
||||
api.create_tag(**tag_kwargs)
|
||||
print(f"[lerobot-annotate] tagged {repo_id} as {version_tag}", flush=True)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(
|
||||
f"[lerobot-annotate] WARNING: could not create tag {version_tag!r} on {repo_id}: {exc}. "
|
||||
"Dataset is uploaded but ``LeRobotDataset`` won't be able to load it until it's tagged. "
|
||||
"Run: from huggingface_hub import HfApi; "
|
||||
f"HfApi().create_tag({repo_id!r}, tag={version_tag!r}, repo_type='dataset', exist_ok=True)",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
annotate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
tests/annotations/__init__.py
Normal file
0
tests/annotations/__init__.py
Normal file
58
tests/annotations/_helpers.py
Normal file
58
tests/annotations/_helpers.py
Normal file
@@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Helpers shared across annotation-pipeline tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
|
||||
def make_canned_responder(
|
||||
responses_by_marker: dict[str, Any],
|
||||
default: Any = None,
|
||||
) -> StubVlmClient:
|
||||
"""Return a stub that picks a response by inspecting the user prompt.
|
||||
|
||||
For each call the responder examines the last user-message text and
|
||||
returns the response keyed by the first marker substring it contains.
|
||||
Falls back to ``default`` if no marker matches.
|
||||
"""
|
||||
|
||||
def responder(messages: list[dict[str, Any]]) -> Any:
|
||||
last_user_text = ""
|
||||
for message in messages:
|
||||
if message.get("role") != "user":
|
||||
continue
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
last_user_text = content
|
||||
elif isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
last_user_text = block.get("text", "")
|
||||
for marker, response in responses_by_marker.items():
|
||||
if marker in last_user_text:
|
||||
return response
|
||||
return default
|
||||
|
||||
return StubVlmClient(responder=responder)
|
||||
|
||||
|
||||
def encode_vqa_answer(payload: dict[str, Any]) -> str:
|
||||
return json.dumps(payload, sort_keys=True)
|
||||
51
tests/annotations/conftest.py
Normal file
51
tests/annotations/conftest.py
Normal file
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared fixtures for annotation-pipeline tests.
|
||||
|
||||
The on-disk dataset builder lives with the other dataset factories in
|
||||
``tests/fixtures/dataset_factories.py`` (:func:`build_annotation_dataset`);
|
||||
these fixtures only wire it into pytest.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.fixtures.dataset_factories import build_annotation_dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fixture_dataset_root(tmp_path: Path) -> Path:
|
||||
"""A tiny dataset with two episodes, 12 frames each at 10 fps."""
|
||||
return build_annotation_dataset(
|
||||
tmp_path / "ds",
|
||||
episode_specs=[
|
||||
(0, 12, "Could you tidy the kitchen please?"),
|
||||
(1, 12, "Please clean up the kitchen"),
|
||||
],
|
||||
fps=10,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def single_episode_root(tmp_path: Path) -> Path:
|
||||
return build_annotation_dataset(
|
||||
tmp_path / "ds_one",
|
||||
episode_specs=[(0, 30, "Pour water from the bottle into the cup.")],
|
||||
fps=10,
|
||||
)
|
||||
101
tests/annotations/run_e2e_smoke.py
Normal file
101
tests/annotations/run_e2e_smoke.py
Normal file
@@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Opt-in E2E smoke run for ``make annotation-e2e``.
|
||||
|
||||
Builds the shared annotation fixture (:func:`build_annotation_dataset`),
|
||||
runs the full annotation pipeline against it with a stub VLM, and prints a
|
||||
short report. This is intentionally not a pytest test — it exercises the
|
||||
CLI plumbing — but it reuses the same on-disk dataset builder as the pytest
|
||||
fixtures so there is no duplicated fixture code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
from tests.fixtures.dataset_factories import build_annotation_dataset
|
||||
|
||||
|
||||
def _stub_responder(messages):
|
||||
text = ""
|
||||
for m in messages:
|
||||
if m.get("role") == "user":
|
||||
content = m.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
elif isinstance(content, str):
|
||||
text = content
|
||||
if "atomic subtasks" in text:
|
||||
return {
|
||||
"subtasks": [
|
||||
{"text": "grasp the bottle", "start": 0.0, "end": 1.0},
|
||||
{"text": "pour into the cup", "start": 1.0, "end": 2.0},
|
||||
{"text": "place the bottle down", "start": 2.0, "end": 3.0},
|
||||
]
|
||||
}
|
||||
if "concise hierarchical PLAN" in text:
|
||||
return {"plan": "1. grasp\n2. pour\n3. place"}
|
||||
if "Update the memory" in text:
|
||||
return {"memory": "poured once"}
|
||||
if "acknowledgement the robot" in text:
|
||||
return {"text": "Sure."}
|
||||
if "ONE realistic interruption" in text:
|
||||
return {"interjection": "use less water", "speech": "Using less water."}
|
||||
if "frame-grounded visual question" in text:
|
||||
return {"question": "How many cups?", "answer": {"label": "cup", "count": 1}}
|
||||
return None
|
||||
|
||||
|
||||
def main() -> int:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = build_annotation_dataset(
|
||||
Path(tmp) / "ds",
|
||||
episode_specs=[(0, 30, "Pour water into the cup.")],
|
||||
fps=10,
|
||||
)
|
||||
vlm = StubVlmClient(responder=_stub_responder)
|
||||
cfg = AnnotationPipelineConfig()
|
||||
executor = Executor(
|
||||
config=cfg,
|
||||
plan=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan),
|
||||
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.interjections, seed=cfg.seed),
|
||||
vqa=GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed),
|
||||
writer=LanguageColumnsWriter(),
|
||||
validator=StagingValidator(),
|
||||
)
|
||||
summary = executor.run(root)
|
||||
print(f"phases={[(p.name, p.episodes_processed) for p in summary.phases]}")
|
||||
print(f"validation: {summary.validation_report.summary()}")
|
||||
print(f"shards rewritten: {len(summary.written_paths)}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
146
tests/annotations/test_frames.py
Normal file
146
tests/annotations/test_frames.py
Normal file
@@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for :class:`VideoFrameProvider` method bindings.
|
||||
|
||||
These were prompted by a real regression: ``video_for_episode`` was once
|
||||
indented one level too deep so it ended up nested *inside* a module-level
|
||||
helper (after that function's ``return`` statement) — silently dead code
|
||||
that meant production runs with ``use_video_url=False`` would
|
||||
``AttributeError`` on ``self.frame_provider.video_for_episode(...)``. The
|
||||
existing module tests didn't catch it because they exercise stub providers.
|
||||
|
||||
The tests below assert on the class itself (not on an instance), so a
|
||||
future reindent regression flips them to red without needing a real
|
||||
LeRobot dataset on disk.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.frames import ( # noqa: E402
|
||||
VideoFrameProvider,
|
||||
_decode_frames_av,
|
||||
_decode_frames_ffmpeg,
|
||||
)
|
||||
|
||||
|
||||
def test_video_for_episode_is_a_method_of_videoframeprovider():
|
||||
"""``video_for_episode`` must be a bound method, not nested dead code."""
|
||||
assert callable(getattr(VideoFrameProvider, "video_for_episode", None))
|
||||
|
||||
|
||||
def test_episode_clip_path_is_a_method_of_videoframeprovider():
|
||||
"""``episode_clip_path`` is now a method (was a free function reaching
|
||||
into ``provider._meta`` from outside the class)."""
|
||||
assert callable(getattr(VideoFrameProvider, "episode_clip_path", None))
|
||||
|
||||
|
||||
def test_videoframeprovider_has_a_lock_for_concurrent_use():
|
||||
"""A ``ThreadPoolExecutor`` runs the plan / interjections / vqa phases
|
||||
concurrently; the cache + warn-flag accesses must be guarded.
|
||||
"""
|
||||
import threading
|
||||
|
||||
# Fresh-instance check via a minimal fake to avoid touching the hub.
|
||||
# The lock is declared with ``init=False`` and has a default factory,
|
||||
# so a constructed instance must own a real ``threading.Lock``.
|
||||
lock_field = next(
|
||||
(f for f in VideoFrameProvider.__dataclass_fields__.values() if f.name == "_lock"),
|
||||
None,
|
||||
)
|
||||
assert lock_field is not None
|
||||
assert lock_field.default_factory is threading.Lock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_video(tmp_path: Path) -> Path:
|
||||
"""A 3 s 10 fps test-pattern mp4, written with ffmpeg."""
|
||||
if shutil.which("ffmpeg") is None:
|
||||
pytest.skip("ffmpeg not available")
|
||||
out = tmp_path / "sample.mp4"
|
||||
subprocess.run(
|
||||
[
|
||||
"ffmpeg", "-y", "-f", "lavfi",
|
||||
"-i", "testsrc=duration=3:size=160x120:rate=10",
|
||||
"-pix_fmt", "yuv420p", str(out),
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def test_decode_frames_av_returns_one_uint8_frame_per_timestamp(sample_video: Path) -> None:
|
||||
"""``_decode_frames_av`` decodes via PyAV directly — no torchcodec/torchvision.
|
||||
|
||||
This is the always-available fallback: torchcodec is unusable in some
|
||||
containers and lerobot's ``pyav`` backend routes through the removed
|
||||
``torchvision.io.VideoReader``.
|
||||
"""
|
||||
timestamps = [0.0, 1.0, 2.5]
|
||||
frames = _decode_frames_av(sample_video, timestamps)
|
||||
|
||||
assert len(frames) == len(timestamps)
|
||||
for frame in frames:
|
||||
assert isinstance(frame, torch.Tensor)
|
||||
assert frame.dtype == torch.uint8
|
||||
assert frame.shape == (3, 120, 160)
|
||||
|
||||
|
||||
def test_decode_frames_av_picks_nearest_frame(sample_video: Path) -> None:
|
||||
"""Repeated and out-of-order timestamps each resolve to the nearest frame."""
|
||||
frames = _decode_frames_av(sample_video, [2.0, 0.0, 2.0])
|
||||
|
||||
assert len(frames) == 3
|
||||
assert torch.equal(frames[0], frames[2])
|
||||
assert not torch.equal(frames[0], frames[1])
|
||||
|
||||
|
||||
def test_decode_frames_av_raises_on_missing_file(tmp_path: Path) -> None:
|
||||
"""A missing video surfaces as an exception the caller can fall back on."""
|
||||
with pytest.raises(Exception): # noqa: B017, PT011
|
||||
_decode_frames_av(tmp_path / "does_not_exist.mp4", [0.0])
|
||||
|
||||
|
||||
def test_decode_frames_ffmpeg_returns_one_uint8_frame_per_timestamp(sample_video: Path) -> None:
|
||||
"""``_decode_frames_ffmpeg`` shells out to the ffmpeg CLI — the always-
|
||||
available fallback that decodes AV1 and isolates crashes to a child
|
||||
process.
|
||||
"""
|
||||
timestamps = [0.0, 1.0, 2.5]
|
||||
frames = _decode_frames_ffmpeg(sample_video, timestamps)
|
||||
|
||||
assert len(frames) == len(timestamps)
|
||||
for frame in frames:
|
||||
assert isinstance(frame, torch.Tensor)
|
||||
assert frame.dtype == torch.uint8
|
||||
assert frame.shape == (3, 120, 160)
|
||||
|
||||
|
||||
def test_decode_frames_ffmpeg_raises_on_missing_file(tmp_path: Path) -> None:
|
||||
"""A missing video raises (non-zero ffmpeg exit), never crashes the job."""
|
||||
if shutil.which("ffmpeg") is None:
|
||||
pytest.skip("ffmpeg not available")
|
||||
with pytest.raises(Exception): # noqa: B017, PT011
|
||||
_decode_frames_ffmpeg(tmp_path / "does_not_exist.mp4", [0.0])
|
||||
355
tests/annotations/test_modules.py
Normal file
355
tests/annotations/test_modules.py
Normal file
@@ -0,0 +1,355 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module 1/2/3 unit tests with stubbed VLMs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import (
|
||||
InterjectionsConfig,
|
||||
PlanConfig,
|
||||
VqaConfig,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
from ._helpers import make_canned_responder
|
||||
|
||||
|
||||
@dataclass
|
||||
class _StubFrameProvider:
|
||||
"""Returns one sentinel object per requested timestamp."""
|
||||
|
||||
sentinel: Any = field(default_factory=lambda: object())
|
||||
cameras: tuple[str, ...] = ("observation.images.top",)
|
||||
calls: list[tuple[int, tuple[float, ...], str | None]] = field(default_factory=list)
|
||||
video_calls: list[tuple[int, int, str | None]] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return list(self.cameras)
|
||||
|
||||
def frames_at(self, record, timestamps, camera_key=None):
|
||||
self.calls.append((record.episode_index, tuple(timestamps), camera_key))
|
||||
return [self.sentinel] * len(timestamps)
|
||||
|
||||
def video_for_episode(self, record, max_frames, camera_key=None):
|
||||
self.video_calls.append((record.episode_index, max_frames, camera_key))
|
||||
n = min(max_frames, len(record.frame_timestamps))
|
||||
return [self.sentinel] * n
|
||||
|
||||
|
||||
def _spy_responder(captured: list[list[dict[str, Any]]], reply: Any):
|
||||
def responder(messages):
|
||||
captured.append(list(messages))
|
||||
return reply
|
||||
|
||||
return StubVlmClient(responder=responder)
|
||||
|
||||
|
||||
def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
vlm = make_canned_responder(
|
||||
{
|
||||
"atomic subtasks": {
|
||||
"subtasks": [
|
||||
{"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.4},
|
||||
{"text": "wipe the counter from left to right", "start": 0.4, "end": 0.8},
|
||||
{"text": "place the sponge into the sink", "start": 0.8, "end": 1.1},
|
||||
]
|
||||
},
|
||||
"Update the memory": {"memory": "wiped the counter once"},
|
||||
},
|
||||
)
|
||||
module = PlanSubtasksMemoryModule(vlm=vlm, config=PlanConfig())
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("plan")
|
||||
|
||||
styles = {r["style"] for r in rows}
|
||||
assert {"subtask", "plan", "memory"}.issubset(styles)
|
||||
# subtask timestamps must be exact frame timestamps
|
||||
frame_set = set(record.frame_timestamps)
|
||||
for row in rows:
|
||||
assert row["timestamp"] in frame_set
|
||||
# one plan row per subtask boundary; the first lands at t0 and each
|
||||
# plan is the deterministic numbered list of still-todo subtasks
|
||||
plan_rows = sorted((r for r in rows if r["style"] == "plan"), key=lambda r: r["timestamp"])
|
||||
subtask_rows = [r for r in rows if r["style"] == "subtask"]
|
||||
assert len(plan_rows) == len(subtask_rows)
|
||||
assert plan_rows[0]["timestamp"] == record.frame_timestamps[0]
|
||||
# the t0 plan enumerates all subtasks; later plans shrink
|
||||
assert plan_rows[0]["content"].startswith("1. ")
|
||||
assert len(plan_rows[0]["content"].splitlines()) == len(subtask_rows)
|
||||
assert len(plan_rows[-1]["content"].splitlines()) == 1
|
||||
|
||||
|
||||
def test_module2_at_t0_emits_speech_only_no_interjection(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
vlm = make_canned_responder(
|
||||
{"acknowledgement the robot": {"text": "Sure, on it."}},
|
||||
)
|
||||
module = InterjectionsAndSpeechModule(
|
||||
vlm=vlm,
|
||||
config=InterjectionsConfig(max_interjections_per_episode=0),
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("interjections")
|
||||
assert len(rows) == 1
|
||||
only = rows[0]
|
||||
assert only["role"] == "assistant"
|
||||
assert only["style"] is None
|
||||
assert only["content"] is None
|
||||
assert only["timestamp"] == record.frame_timestamps[0]
|
||||
assert only["tool_calls"][0]["function"]["name"] == "say"
|
||||
|
||||
|
||||
def test_module2_mid_episode_emits_paired_interjection_and_speech(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""Module 2 anchors interjections on Module 1's subtask boundaries.
|
||||
|
||||
The executor runs Module 1 first, then Module 2 reads the subtask
|
||||
rows back from the same staging tree (see
|
||||
``_mid_episode_interjections``). Reproduce that contract here by
|
||||
seeding the staging with two subtask rows so a single ``0 → 1``
|
||||
boundary exists for Module 2 to anchor on.
|
||||
"""
|
||||
vlm = make_canned_responder(
|
||||
{
|
||||
"acknowledgement the robot": {"text": "OK."},
|
||||
# Marker matches the distinctive line of
|
||||
# ``module_2_interjection.txt``. The old marker
|
||||
# ("ONE realistic interruption") came from a previous prompt
|
||||
# version that asked for counterfactual interjections; the
|
||||
# current design anchors on subtask boundaries instead, so
|
||||
# the prompt and its marker changed.
|
||||
"Write ONE interjection": {
|
||||
"interjection": "now wipe the counter please",
|
||||
"speech": "On it.",
|
||||
},
|
||||
},
|
||||
)
|
||||
module = InterjectionsAndSpeechModule(
|
||||
vlm=vlm,
|
||||
config=InterjectionsConfig(max_interjections_per_episode=1, interjection_min_t=0.2),
|
||||
seed=7,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
# Seed Module 1's subtask staging so Module 2 has a boundary to
|
||||
# anchor on (it bails with zero rows when no spans exist — the
|
||||
# production executor guarantees Module 1 ran first).
|
||||
boundary_ts = float(record.frame_timestamps[len(record.frame_timestamps) // 2])
|
||||
staging.write(
|
||||
"plan",
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "grasp the sponge",
|
||||
"style": "subtask",
|
||||
"timestamp": float(record.frame_timestamps[0]),
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "wipe the counter",
|
||||
"style": "subtask",
|
||||
"timestamp": boundary_ts,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("interjections")
|
||||
|
||||
interjections = [r for r in rows if r["style"] == "interjection"]
|
||||
speeches = [r for r in rows if r["style"] is None and r["role"] == "assistant"]
|
||||
assert len(interjections) == 1
|
||||
assert len(speeches) >= 2 # initial t=0 + one paired with the interjection
|
||||
inter_t = interjections[0]["timestamp"]
|
||||
assert any(abs(s["timestamp"] - inter_t) < 1e-9 for s in speeches)
|
||||
|
||||
|
||||
def test_module3_vqa_unique_per_frame_and_camera(single_episode_root: Path, tmp_path: Path) -> None:
|
||||
payload = {
|
||||
"question": "How many cups?",
|
||||
"answer": {"label": "cup", "count": 2, "note": "white & blue"},
|
||||
}
|
||||
vlm = make_canned_responder({"frame-grounded visual question": payload})
|
||||
module = GeneralVqaModule(
|
||||
vlm=vlm,
|
||||
config=VqaConfig(vqa_emission_hz=1.0, K=3),
|
||||
seed=1,
|
||||
frame_provider=_StubFrameProvider(cameras=("observation.images.top", "observation.images.wrist")),
|
||||
)
|
||||
record = next(iter_episodes(single_episode_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("vqa")
|
||||
# every vqa row must carry a camera tag and one of the configured cameras
|
||||
for r in rows:
|
||||
assert r["style"] == "vqa"
|
||||
assert r.get("camera") in {"observation.images.top", "observation.images.wrist"}
|
||||
# at most one (vqa, user) and one (vqa, assistant) per (timestamp, camera)
|
||||
user_keys = [(r["timestamp"], r["camera"]) for r in rows if r["role"] == "user" and r["style"] == "vqa"]
|
||||
assistant_keys = [
|
||||
(r["timestamp"], r["camera"]) for r in rows if r["role"] == "assistant" and r["style"] == "vqa"
|
||||
]
|
||||
assert len(user_keys) == len(set(user_keys))
|
||||
assert len(assistant_keys) == len(set(assistant_keys))
|
||||
# both cameras must be represented
|
||||
assert {c for _, c in user_keys} == {"observation.images.top", "observation.images.wrist"}
|
||||
# every emitted timestamp must be an exact source frame timestamp
|
||||
frame_set = set(record.frame_timestamps)
|
||||
for ts, _ in user_keys + assistant_keys:
|
||||
assert ts in frame_set
|
||||
|
||||
|
||||
def test_module1_attaches_video_block_to_subtask_prompt(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
"""Module 1 sends one ``type=video`` block covering the whole episode."""
|
||||
captured: list[list[dict[str, Any]]] = []
|
||||
payload = {
|
||||
"subtasks": [
|
||||
{"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.5},
|
||||
{"text": "wipe the counter", "start": 0.5, "end": 1.1},
|
||||
]
|
||||
}
|
||||
plan_payload = {"plan": "1. grasp\n2. wipe"}
|
||||
memory_payload = {"memory": "wiped once"}
|
||||
|
||||
def responder(messages):
|
||||
captured.append(list(messages))
|
||||
text = ""
|
||||
for m in messages:
|
||||
for block in m.get("content", []):
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
if "concise hierarchical PLAN" in text:
|
||||
return plan_payload
|
||||
if "Update the memory" in text:
|
||||
return memory_payload
|
||||
return payload
|
||||
|
||||
provider = _StubFrameProvider()
|
||||
module = PlanSubtasksMemoryModule(
|
||||
vlm=StubVlmClient(responder=responder),
|
||||
# Disable the rephrasings sub-prompt so the test's only video-bearing
|
||||
# call is the subtask one — keeps the assertions below focused on
|
||||
# ``_generate_subtasks`` rather than fighting the order of unrelated
|
||||
# text-only Module-1 sub-prompts.
|
||||
config=PlanConfig(max_video_frames=5, frames_per_second=10.0, n_task_rephrasings=0),
|
||||
frame_provider=provider,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
|
||||
# Find the call carrying the subtask prompt rather than blindly taking
|
||||
# captured[0] — Module 1 issues several sub-prompts and their order is
|
||||
# not part of the contract.
|
||||
assert captured, "no VLM calls made"
|
||||
|
||||
def _prompt_text(messages):
|
||||
for m in messages:
|
||||
for block in m.get("content", []):
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
return ""
|
||||
|
||||
subtask_calls = [m for m in captured if "atomic subtasks" in _prompt_text(m)]
|
||||
assert len(subtask_calls) == 1, "expected exactly one subtask-prompt VLM call"
|
||||
content = subtask_calls[0][0]["content"]
|
||||
video_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "video"]
|
||||
image_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "image"]
|
||||
text_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||
assert len(video_blocks) == 1, f"expected exactly 1 video block, got {content}"
|
||||
assert image_blocks == [], "subtask prompt must not mix image blocks with the video block"
|
||||
assert len(text_blocks) == 1
|
||||
# video block must wrap a list of frames covering the episode
|
||||
assert isinstance(video_blocks[0]["video"], list)
|
||||
assert len(video_blocks[0]["video"]) <= 5
|
||||
# provider is called with target_count = min(duration * fps, max). With
|
||||
# fps=10 on a ~1s episode that requests >max, so max=5 wins.
|
||||
assert provider.video_calls and provider.video_calls[0][0] == record.episode_index
|
||||
assert provider.video_calls[0][1] <= 5
|
||||
|
||||
|
||||
def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None:
|
||||
"""Each VQA prompt must carry a single image block at the emission frame."""
|
||||
captured: list[list[dict[str, Any]]] = []
|
||||
payload = {
|
||||
"question": "How many cups?",
|
||||
"answer": {"label": "cup", "count": 1},
|
||||
}
|
||||
provider = _StubFrameProvider()
|
||||
module = GeneralVqaModule(
|
||||
vlm=_spy_responder(captured, payload),
|
||||
config=VqaConfig(vqa_emission_hz=1.0, K=1),
|
||||
seed=0,
|
||||
frame_provider=provider,
|
||||
)
|
||||
record = next(iter_episodes(single_episode_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
|
||||
assert captured, "no VLM calls made"
|
||||
for messages in captured:
|
||||
content = messages[0]["content"]
|
||||
image_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "image"]
|
||||
text_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||
assert len(image_blocks) == 1, f"expected 1 image block per VQA prompt, got {content}"
|
||||
assert image_blocks[0]["image"] is provider.sentinel
|
||||
assert len(text_blocks) == 1
|
||||
# provider was called once per emission per camera with the exact emission timestamp
|
||||
for ep_idx, ts_tuple, camera in provider.calls:
|
||||
assert ep_idx == record.episode_index
|
||||
assert len(ts_tuple) == 1
|
||||
assert ts_tuple[0] in record.frame_timestamps
|
||||
assert camera in provider.cameras
|
||||
|
||||
|
||||
def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
|
||||
payload = {
|
||||
"question": "Where is the cup?",
|
||||
"answer": {"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]},
|
||||
}
|
||||
vlm = make_canned_responder({"frame-grounded visual question": payload})
|
||||
module = GeneralVqaModule(
|
||||
vlm=vlm,
|
||||
config=VqaConfig(vqa_emission_hz=1.0, K=2),
|
||||
seed=2,
|
||||
frame_provider=_StubFrameProvider(),
|
||||
)
|
||||
record = next(iter_episodes(single_episode_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("vqa")
|
||||
for row in rows:
|
||||
if row["role"] == "assistant" and row["style"] == "vqa":
|
||||
decoded = json.loads(row["content"])
|
||||
assert "detections" in decoded
|
||||
175
tests/annotations/test_pipeline_recipe_render.py
Normal file
175
tests/annotations/test_pipeline_recipe_render.py
Normal file
@@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""End-to-end smoke: pipeline output → PR 1 canonical recipe rendering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import (
|
||||
AnnotationPipelineConfig,
|
||||
InterjectionsConfig,
|
||||
PlanConfig,
|
||||
VqaConfig,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
|
||||
from lerobot.datasets.language_render import render_sample
|
||||
|
||||
from ._helpers import make_canned_responder
|
||||
|
||||
|
||||
def _build_pr1_style_blend_recipe() -> TrainingRecipe:
|
||||
"""Inline blend recipe that consumes every style this pipeline produces.
|
||||
|
||||
PR 1 used to ship ``src/lerobot/configs/recipes/pi05_hirobot.yaml`` as
|
||||
a canonical example, but that file was dropped during PR 1 review. The
|
||||
cross-PR contract this test guards is "the recipe DSL can render
|
||||
non-empty messages from pipeline output", which doesn't require a
|
||||
specific YAML — so we build the equivalent blend in code.
|
||||
"""
|
||||
return TrainingRecipe(
|
||||
blend={
|
||||
"low_level_execution": TrainingRecipe(
|
||||
weight=0.35,
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content="${task}\nPlan: ${plan}\nMemory: ${memory}",
|
||||
stream="high_level",
|
||||
),
|
||||
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
|
||||
],
|
||||
),
|
||||
"user_interjection_response": TrainingRecipe(
|
||||
weight=0.16,
|
||||
bindings={
|
||||
"speech": "emitted_at(t, role=assistant, tool_name=say)",
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content="${interjection}",
|
||||
stream="high_level",
|
||||
if_present="interjection",
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${plan}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="plan",
|
||||
tool_calls_from="speech",
|
||||
),
|
||||
],
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _build_executor() -> Executor:
|
||||
vlm = make_canned_responder(
|
||||
{
|
||||
"atomic subtasks": {
|
||||
"subtasks": [
|
||||
{"text": "grasp the bottle", "start": 0.0, "end": 0.5},
|
||||
{"text": "pour into the cup", "start": 0.5, "end": 1.0},
|
||||
{"text": "place the bottle down", "start": 1.0, "end": 1.5},
|
||||
]
|
||||
},
|
||||
"concise hierarchical PLAN": {"plan": "1. grasp\n2. pour\n3. place"},
|
||||
"Update the memory": {"memory": "poured once"},
|
||||
"acknowledgement the robot": {"text": "Sure."},
|
||||
"ONE realistic interruption": {
|
||||
"interjection": "use less water",
|
||||
"speech": "Using less water.",
|
||||
},
|
||||
"frame-grounded visual question": {
|
||||
"question": "How many cups?",
|
||||
"answer": {"label": "cup", "count": 1},
|
||||
},
|
||||
},
|
||||
)
|
||||
config = AnnotationPipelineConfig(
|
||||
plan=PlanConfig(),
|
||||
interjections=InterjectionsConfig(max_interjections_per_episode=1, interjection_min_t=0.5),
|
||||
vqa=VqaConfig(vqa_emission_hz=1.0, K=2),
|
||||
)
|
||||
return Executor(
|
||||
config=config,
|
||||
plan=PlanSubtasksMemoryModule(vlm=vlm, config=config.plan),
|
||||
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=config.interjections, seed=config.seed),
|
||||
vqa=GeneralVqaModule(vlm=vlm, config=config.vqa, seed=config.seed),
|
||||
writer=LanguageColumnsWriter(),
|
||||
validator=StagingValidator(),
|
||||
)
|
||||
|
||||
|
||||
def test_pr1_canonical_recipe_renders_nonempty_from_pipeline_output(
|
||||
single_episode_root: Path,
|
||||
) -> None:
|
||||
executor = _build_executor()
|
||||
summary = executor.run(single_episode_root)
|
||||
# validator may emit warnings but no errors for the synthetic fixture
|
||||
assert summary.validation_report.ok, summary.validation_report.summary()
|
||||
|
||||
table = pq.read_table(single_episode_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
persistent_lists = table.column("language_persistent").to_pylist()
|
||||
events_lists = table.column("language_events").to_pylist()
|
||||
timestamps = table.column("timestamp").to_pylist()
|
||||
|
||||
recipe = _build_pr1_style_blend_recipe()
|
||||
|
||||
rendered_any = False
|
||||
for ts, persistent, events in zip(timestamps, persistent_lists, events_lists, strict=True):
|
||||
result = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=float(ts),
|
||||
sample_idx=0,
|
||||
dataset_ctx={"task": "Pour water from the bottle into the cup."},
|
||||
)
|
||||
if result is None:
|
||||
continue
|
||||
if result["messages"]:
|
||||
rendered_any = True
|
||||
assert result["target_message_indices"]
|
||||
break
|
||||
assert rendered_any, "PR 1 recipe rendered no messages from pipeline output"
|
||||
|
||||
# Sanity: speech atom appears in events column intact
|
||||
flat_events = [r for ev in events_lists for r in ev]
|
||||
speech_rows = [r for r in flat_events if r.get("style") is None and r.get("role") == "assistant"]
|
||||
assert speech_rows
|
||||
say = speech_rows[0]["tool_calls"][0]
|
||||
assert say["function"]["name"] == "say"
|
||||
assert isinstance(say["function"]["arguments"]["text"], str)
|
||||
# PR 2 no longer writes a ``tools`` column — the say schema lives as a
|
||||
# constant (``SAY_TOOL_SCHEMA``) so PR 1's row struct is the single
|
||||
# source of truth for the v3.1 schema.
|
||||
assert "tools" not in table.column_names
|
||||
125
tests/annotations/test_validator.py
Normal file
125
tests/annotations/test_validator.py
Normal file
@@ -0,0 +1,125 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Validator behavior tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.writer import speech_atom
|
||||
|
||||
|
||||
def _validate(root: Path, staging_dir: Path):
|
||||
records = list(iter_episodes(root))
|
||||
return StagingValidator().validate(records, staging_dir)
|
||||
|
||||
|
||||
def test_validator_catches_misaligned_timestamps(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"vqa",
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps({"label": "cup", "count": 2}, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": 9.999, # not on any 10 fps frame
|
||||
"tool_calls": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
assert not report.ok
|
||||
assert any("does not match any source frame timestamp" in e for e in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_orphan_speech(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"interjections",
|
||||
[
|
||||
speech_atom(0.0, "Got it."),
|
||||
# interjection at 0.3s with NO paired speech
|
||||
{
|
||||
"role": "user",
|
||||
"content": "skip it",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.3,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
assert not report.ok
|
||||
assert any("paired speech" in e for e in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"plan",
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "1. do x",
|
||||
"style": "plan",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do x",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"interjections",
|
||||
[
|
||||
speech_atom(0.0, "Got it."),
|
||||
speech_atom(0.4, "Replanning."),
|
||||
{
|
||||
"role": "user",
|
||||
"content": "replan",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.4,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
# missing co-timestamped plan refresh at 0.4s → error
|
||||
assert not report.ok
|
||||
assert any("co-timestamped plan update" in e for e in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_wrong_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"plan",
|
||||
[
|
||||
{"role": "user", "content": "where?", "style": "vqa", "timestamp": 0.0, "tool_calls": None},
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
assert not report.ok
|
||||
assert any("plan emitted style 'vqa'" in e or "must be persistent" in e for e in report.errors)
|
||||
412
tests/annotations/test_vocabulary.py
Normal file
412
tests/annotations/test_vocabulary.py
Normal file
@@ -0,0 +1,412 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Vocabulary-discovery phase (phase 0) tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import (
|
||||
PlanConfig,
|
||||
VocabularyConfig,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.modules import PlanSubtasksMemoryModule
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||
from lerobot.annotations.steerable_pipeline.vocabulary import (
|
||||
Vocabulary,
|
||||
VocabularyDiscoveryModule,
|
||||
load_vocabulary,
|
||||
save_vocabulary,
|
||||
vocabulary_path,
|
||||
)
|
||||
|
||||
from ._helpers import make_canned_responder
|
||||
|
||||
|
||||
_CANONICAL_SUBTASKS = (
|
||||
"grasp blue cube",
|
||||
"place blue cube in box",
|
||||
"retract arm",
|
||||
)
|
||||
_CANONICAL_MEMORY = (
|
||||
"I picked up the blue cube.",
|
||||
"I placed the blue cube in the box.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vocabulary dataclass + on-disk round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_vocabulary_roundtrip(tmp_path: Path) -> None:
|
||||
vocab = Vocabulary(
|
||||
subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY
|
||||
)
|
||||
save_path = save_vocabulary(tmp_path, vocab)
|
||||
assert save_path == vocabulary_path(tmp_path)
|
||||
assert save_path.exists()
|
||||
|
||||
loaded = load_vocabulary(tmp_path)
|
||||
assert loaded is not None
|
||||
assert loaded.subtasks == _CANONICAL_SUBTASKS
|
||||
assert loaded.memory_milestones == _CANONICAL_MEMORY
|
||||
|
||||
|
||||
def test_vocabulary_load_missing_returns_none(tmp_path: Path) -> None:
|
||||
assert load_vocabulary(tmp_path) is None
|
||||
|
||||
|
||||
def test_vocabulary_load_malformed_returns_none(tmp_path: Path) -> None:
|
||||
path = vocabulary_path(tmp_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text("{ not valid json", encoding="utf-8")
|
||||
assert load_vocabulary(tmp_path) is None
|
||||
|
||||
|
||||
def test_vocabulary_load_empty_payload_returns_none(tmp_path: Path) -> None:
|
||||
path = vocabulary_path(tmp_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps({"subtasks": [], "memory_milestones": []}), encoding="utf-8")
|
||||
assert load_vocabulary(tmp_path) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discovery module
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_vocabulary_discovery_calls_vlm_and_returns_vocab(
|
||||
fixture_dataset_root: Path,
|
||||
) -> None:
|
||||
vlm = make_canned_responder(
|
||||
{
|
||||
"canonical vocabulary": {
|
||||
"subtasks": list(_CANONICAL_SUBTASKS),
|
||||
"memory_milestones": list(_CANONICAL_MEMORY),
|
||||
}
|
||||
}
|
||||
)
|
||||
module = VocabularyDiscoveryModule(vlm=vlm, config=VocabularyConfig(sample_episodes=2))
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
vocab = module.discover(records)
|
||||
assert vocab is not None
|
||||
assert vocab.subtasks == _CANONICAL_SUBTASKS
|
||||
assert vocab.memory_milestones == _CANONICAL_MEMORY
|
||||
|
||||
|
||||
def test_vocabulary_discovery_reuses_existing(fixture_dataset_root: Path) -> None:
|
||||
"""``reuse_existing=True`` short-circuits the VLM call entirely."""
|
||||
|
||||
def _explode(_messages): # pragma: no cover - must not be called
|
||||
raise AssertionError("VLM should not be invoked when reusing existing vocabulary")
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
vlm = StubVlmClient(responder=_explode)
|
||||
module = VocabularyDiscoveryModule(
|
||||
vlm=vlm, config=VocabularyConfig(reuse_existing=True)
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
existing = Vocabulary(subtasks=("a", "b"), memory_milestones=("I a.",))
|
||||
vocab = module.discover(records, existing=existing)
|
||||
assert vocab is existing
|
||||
|
||||
|
||||
def test_vocabulary_discovery_empty_payload_returns_none(
|
||||
fixture_dataset_root: Path,
|
||||
) -> None:
|
||||
vlm = make_canned_responder({"canonical vocabulary": {"subtasks": [], "memory_milestones": []}})
|
||||
module = VocabularyDiscoveryModule(vlm=vlm, config=VocabularyConfig())
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
assert module.discover(records) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PlanSubtasksMemoryModule consumes the vocabulary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_plan_module_inlines_vocab_into_subtask_prompt(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
captured: list[str] = []
|
||||
|
||||
def responder(messages):
|
||||
# Find the last user text block and stash it for inspection.
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
captured.append(block.get("text", ""))
|
||||
# Return canned subtasks; pick the first two canonical strings so
|
||||
# the validator accepts them.
|
||||
return {
|
||||
"subtasks": [
|
||||
{"text": "grasp blue cube", "start": 0.0, "end": 0.4},
|
||||
{"text": "place blue cube in box", "start": 0.4, "end": 0.9},
|
||||
]
|
||||
}
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
vlm = StubVlmClient(responder=responder)
|
||||
vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY)
|
||||
module = PlanSubtasksMemoryModule(
|
||||
vlm=vlm,
|
||||
config=PlanConfig(n_task_rephrasings=0),
|
||||
vocabulary=vocab,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
# The subtask prompt (and the memory prompt) carries the canonical
|
||||
# bullet list so the VLM can't paraphrase them away.
|
||||
assert any("Canonical subtask labels:" in t for t in captured)
|
||||
assert any("grasp blue cube" in t for t in captured)
|
||||
|
||||
|
||||
def test_plan_module_accepts_article_only_difference(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""Articles like 'the'/'a'/'an' are stripped during validation."""
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
def responder(_messages):
|
||||
return {
|
||||
"subtasks": [
|
||||
# Same canonical phrase modulo "the" — should be accepted.
|
||||
{"text": "grasp the blue cube", "start": 0.0, "end": 0.4},
|
||||
]
|
||||
}
|
||||
|
||||
vlm = StubVlmClient(responder=responder)
|
||||
vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY)
|
||||
module = PlanSubtasksMemoryModule(
|
||||
vlm=vlm,
|
||||
config=PlanConfig(n_task_rephrasings=0),
|
||||
vocabulary=vocab,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("plan")
|
||||
subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"]
|
||||
assert subtask_texts == ["grasp blue cube"]
|
||||
|
||||
|
||||
def test_plan_module_retries_when_subtask_off_vocab(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""One-shot retry replaces an off-vocab paraphrase with the canonical form."""
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def responder(messages):
|
||||
call_count["n"] += 1
|
||||
# First call: returns an off-vocab paraphrase.
|
||||
if call_count["n"] == 1:
|
||||
return {
|
||||
"subtasks": [
|
||||
# paraphrase, not in vocab
|
||||
{"text": "pick up blue cube", "start": 0.0, "end": 0.4},
|
||||
]
|
||||
}
|
||||
# Second call (the retry): should contain the correction prompt;
|
||||
# respond with the canonical phrase exactly.
|
||||
last_user_text = ""
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
last_user_text = content
|
||||
elif isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
last_user_text = block.get("text", "")
|
||||
assert "NOT in the canonical vocabulary" in last_user_text
|
||||
return {
|
||||
"subtasks": [
|
||||
{"text": "grasp blue cube", "start": 0.0, "end": 0.4},
|
||||
]
|
||||
}
|
||||
|
||||
vlm = StubVlmClient(responder=responder)
|
||||
vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY)
|
||||
module = PlanSubtasksMemoryModule(
|
||||
vlm=vlm,
|
||||
config=PlanConfig(n_task_rephrasings=0),
|
||||
vocabulary=vocab,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("plan")
|
||||
subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"]
|
||||
assert subtask_texts == ["grasp blue cube"]
|
||||
# The retry must have fired exactly once.
|
||||
assert call_count["n"] == 2
|
||||
|
||||
|
||||
def test_plan_module_drops_off_vocab_subtask_after_retry(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""If the VLM stays off-vocab even after the retry, the bad span is dropped."""
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def responder(_messages):
|
||||
call_count["n"] += 1
|
||||
# Both calls return the same off-vocab span — the model can't
|
||||
# be corrected. The second call also returns one in-vocab span
|
||||
# so the episode isn't empty; this lets us check that the
|
||||
# off-vocab span is dropped without affecting the in-vocab one.
|
||||
if call_count["n"] == 1:
|
||||
return {
|
||||
"subtasks": [
|
||||
{"text": "perform a fancy macarena dance", "start": 0.0, "end": 0.4},
|
||||
{"text": "grasp blue cube", "start": 0.4, "end": 0.9},
|
||||
]
|
||||
}
|
||||
return {
|
||||
"subtasks": [
|
||||
{"text": "perform a fancy macarena dance", "start": 0.0, "end": 0.4},
|
||||
{"text": "grasp blue cube", "start": 0.4, "end": 0.9},
|
||||
]
|
||||
}
|
||||
|
||||
vlm = StubVlmClient(responder=responder)
|
||||
vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY)
|
||||
module = PlanSubtasksMemoryModule(
|
||||
vlm=vlm,
|
||||
config=PlanConfig(n_task_rephrasings=0),
|
||||
vocabulary=vocab,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("plan")
|
||||
subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"]
|
||||
# Retry fired exactly once; bad span dropped, good span kept.
|
||||
assert call_count["n"] == 2
|
||||
assert subtask_texts == ["grasp blue cube"]
|
||||
|
||||
|
||||
def test_plan_module_bumps_collocated_subtasks_to_distinct_frames(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""Two subtasks whose starts snap to the same frame get split onto two frames.
|
||||
|
||||
Without this guard, both spans would emit ``style=subtask`` rows at the
|
||||
identical persistent timestamp; the training-time renderer's
|
||||
``active_at(t, style=subtask)`` then raises an ambiguity error.
|
||||
"""
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
def responder(_messages):
|
||||
# Two canonical labels with starts within one frame of each other —
|
||||
# both snap to the same source frame, so the dedupe pass must bump
|
||||
# the later one to the next frame.
|
||||
return {
|
||||
"subtasks": [
|
||||
{"text": "grasp blue cube", "start": 0.40, "end": 0.42},
|
||||
{"text": "place blue cube in box", "start": 0.41, "end": 0.50},
|
||||
]
|
||||
}
|
||||
|
||||
vlm = StubVlmClient(responder=responder)
|
||||
vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY)
|
||||
module = PlanSubtasksMemoryModule(
|
||||
vlm=vlm,
|
||||
config=PlanConfig(n_task_rephrasings=0),
|
||||
vocabulary=vocab,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("plan")
|
||||
subtask_rows = [r for r in rows if r["style"] == "subtask"]
|
||||
# Both subtasks present, both on distinct timestamps.
|
||||
assert len(subtask_rows) == 2
|
||||
timestamps = [r["timestamp"] for r in subtask_rows]
|
||||
assert len(set(timestamps)) == 2, f"subtask timestamps collide: {timestamps}"
|
||||
# Order preserved: the chronologically earlier span keeps the earlier
|
||||
# frame, the later one was bumped onto the next available frame.
|
||||
assert subtask_rows[0]["content"] == "grasp blue cube"
|
||||
assert subtask_rows[1]["content"] == "place blue cube in box"
|
||||
assert subtask_rows[1]["timestamp"] > subtask_rows[0]["timestamp"]
|
||||
|
||||
|
||||
def test_plan_module_empty_when_all_off_vocab_after_retry(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""All-off-vocab spans → episode comes out empty (no silent fuzzy snap)."""
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
def responder(_messages):
|
||||
# Returns the same off-vocab spans on both attempts.
|
||||
return {
|
||||
"subtasks": [
|
||||
{"text": "make a smoothie", "start": 0.0, "end": 0.4},
|
||||
{"text": "consult the wizard", "start": 0.4, "end": 0.9},
|
||||
]
|
||||
}
|
||||
|
||||
vlm = StubVlmClient(responder=responder)
|
||||
vocab = Vocabulary(subtasks=_CANONICAL_SUBTASKS, memory_milestones=_CANONICAL_MEMORY)
|
||||
module = PlanSubtasksMemoryModule(
|
||||
vlm=vlm,
|
||||
config=PlanConfig(n_task_rephrasings=0),
|
||||
vocabulary=vocab,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("plan")
|
||||
subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"]
|
||||
# No subtask gets fabricated — better to leave the episode empty
|
||||
# so the operator notices the vocabulary gap than to silently
|
||||
# warp the labels.
|
||||
assert subtask_texts == []
|
||||
|
||||
|
||||
def test_plan_module_without_vocab_passes_through(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""No vocabulary configured → original free-form behavior is preserved."""
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
def responder(_messages):
|
||||
return {
|
||||
"subtasks": [
|
||||
{"text": "any free-form text the VLM wants", "start": 0.0, "end": 1.0},
|
||||
]
|
||||
}
|
||||
|
||||
vlm = StubVlmClient(responder=responder)
|
||||
module = PlanSubtasksMemoryModule(
|
||||
vlm=vlm, config=PlanConfig(n_task_rephrasings=0)
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("plan")
|
||||
subtask_texts = [r["content"] for r in rows if r["style"] == "subtask"]
|
||||
assert subtask_texts == ["any free-form text the VLM wants"]
|
||||
350
tests/annotations/test_writer.py
Normal file
350
tests/annotations/test_writer.py
Normal file
@@ -0,0 +1,350 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Writer correctness tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||
from lerobot.annotations.steerable_pipeline.writer import (
|
||||
LanguageColumnsWriter,
|
||||
speech_atom,
|
||||
)
|
||||
|
||||
|
||||
def _stage_episode(
|
||||
staging_dir: Path,
|
||||
episode_index: int,
|
||||
*,
|
||||
plan: list[dict] | None = None,
|
||||
interjections: list[dict] | None = None,
|
||||
vqa: list[dict] | None = None,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, episode_index)
|
||||
if plan is not None:
|
||||
staging.write("plan", plan)
|
||||
if interjections is not None:
|
||||
staging.write("interjections", interjections)
|
||||
if vqa is not None:
|
||||
staging.write("vqa", vqa)
|
||||
|
||||
|
||||
def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
"""Every frame in an episode has a byte-identical persistent list."""
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "grasp the sponge",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "1. wipe\n2. dry",
|
||||
"style": "plan",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "wiped the counter",
|
||||
"style": "memory",
|
||||
"timestamp": 0.5,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
persistent = table.column("language_persistent").to_pylist()
|
||||
first = persistent[0]
|
||||
assert first # non-empty
|
||||
for row in persistent:
|
||||
assert row == first, "persistent slice must be byte-identical across all frames"
|
||||
|
||||
|
||||
def test_writer_events_exact_timestamp(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
interjections=[
|
||||
speech_atom(0.0, "Got it."),
|
||||
{
|
||||
"role": "user",
|
||||
"content": "skip the dishes",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.5,
|
||||
"tool_calls": None,
|
||||
},
|
||||
speech_atom(0.5, "Skipping the dishes."),
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
timestamps = table.column("timestamp").to_pylist()
|
||||
events = table.column("language_events").to_pylist()
|
||||
for ts, ev in zip(timestamps, events, strict=True):
|
||||
if abs(ts - 0.0) < 1e-9:
|
||||
assert any(r["role"] == "assistant" and r.get("style") is None for r in ev), ev
|
||||
elif abs(ts - 0.5) < 1e-9:
|
||||
assert any(r.get("style") == "interjection" for r in ev), ev
|
||||
assert any(r.get("style") is None for r in ev), ev
|
||||
else:
|
||||
assert ev == []
|
||||
|
||||
|
||||
def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do X",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "1. do X",
|
||||
"style": "plan",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "did X",
|
||||
"style": "memory",
|
||||
"timestamp": 0.3,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
interjections=[
|
||||
speech_atom(0.0, "OK"),
|
||||
{
|
||||
"role": "user",
|
||||
"content": "wait",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.2,
|
||||
"tool_calls": None,
|
||||
},
|
||||
speech_atom(0.2, "Waiting"),
|
||||
],
|
||||
vqa=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "where is the cup?",
|
||||
"style": "vqa",
|
||||
"timestamp": 0.4,
|
||||
"camera": "observation.images.front",
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(
|
||||
{"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [1, 2, 3, 4]}]},
|
||||
sort_keys=True,
|
||||
),
|
||||
"style": "vqa",
|
||||
"timestamp": 0.4,
|
||||
"camera": "observation.images.front",
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
|
||||
persistent = table.column("language_persistent").to_pylist()[0]
|
||||
persistent_styles = {r["style"] for r in persistent}
|
||||
assert persistent_styles == {"subtask", "plan", "memory"}
|
||||
|
||||
all_events = [r for ev in table.column("language_events").to_pylist() for r in ev]
|
||||
event_styles = {r.get("style") for r in all_events}
|
||||
assert event_styles == {None, "interjection", "vqa"}
|
||||
|
||||
|
||||
def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do X",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
writer = LanguageColumnsWriter()
|
||||
writer.write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
|
||||
table_a = pq.read_table(path)
|
||||
assert "subtask_index" not in table_a.column_names
|
||||
assert "language_persistent" in table_a.column_names
|
||||
assert "language_events" in table_a.column_names
|
||||
# The writer no longer emits a dataset-level ``tools`` column; the
|
||||
# ``say`` tool schema lives as a code constant (``SAY_TOOL_SCHEMA``)
|
||||
# so the parquet stays small and PR 2 doesn't extend PR 1's schema.
|
||||
assert "tools" not in table_a.column_names
|
||||
|
||||
# second pass — must produce identical bytes for the language columns
|
||||
records_again = list(iter_episodes(fixture_dataset_root))
|
||||
writer.write_all(records_again, staging_dir, fixture_dataset_root)
|
||||
table_b = pq.read_table(path)
|
||||
assert (
|
||||
table_a.column("language_persistent").to_pylist() == table_b.column("language_persistent").to_pylist()
|
||||
)
|
||||
assert table_a.column("language_events").to_pylist() == table_b.column("language_events").to_pylist()
|
||||
|
||||
|
||||
def test_writer_normalize_rejects_misrouted_persistent_style() -> None:
|
||||
"""``_normalize_persistent_row`` must reject any non-persistent style."""
|
||||
from lerobot.annotations.steerable_pipeline.writer import _normalize_persistent_row
|
||||
|
||||
with pytest.raises(ValueError, match="non-persistent style"):
|
||||
_normalize_persistent_row(
|
||||
{"role": "assistant", "content": "oops", "style": "vqa", "timestamp": 0.0, "tool_calls": None}
|
||||
)
|
||||
|
||||
|
||||
def test_writer_normalize_rejects_misrouted_event_style() -> None:
|
||||
"""``_normalize_event_row`` must reject any persistent style."""
|
||||
from lerobot.annotations.steerable_pipeline.writer import _normalize_event_row
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_normalize_event_row({"role": "assistant", "content": "oops", "style": "subtask", "tool_calls": None})
|
||||
|
||||
|
||||
def test_say_tool_schema_constant_is_well_formed() -> None:
|
||||
"""``SAY_TOOL_SCHEMA`` (and ``DEFAULT_TOOLS``) replace the parquet
|
||||
``tools`` column — chat-template consumers import them directly.
|
||||
"""
|
||||
from lerobot.annotations.steerable_pipeline.writer import (
|
||||
DEFAULT_TOOLS,
|
||||
SAY_TOOL_SCHEMA,
|
||||
)
|
||||
|
||||
assert DEFAULT_TOOLS == [SAY_TOOL_SCHEMA]
|
||||
assert SAY_TOOL_SCHEMA["function"]["name"] == "say"
|
||||
params = SAY_TOOL_SCHEMA["function"]["parameters"]
|
||||
assert params["properties"]["text"]["type"] == "string"
|
||||
assert params["required"] == ["text"]
|
||||
|
||||
|
||||
def test_writer_does_not_add_tools_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
"""Re-running on a parquet that already has a legacy ``tools`` column
|
||||
must drop it cleanly so reruns converge to the v3.1 schema.
|
||||
"""
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{"role": "assistant", "content": "x", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
assert "tools" not in table.column_names
|
||||
|
||||
|
||||
def test_annotation_metadata_sync_allows_non_streaming_load(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""Annotated parquet columns must be declared in ``meta/info.json``.
|
||||
|
||||
``LeRobotDataset`` loads non-streaming datasets by casting parquet
|
||||
against metadata-derived HF features. If the annotation writer adds
|
||||
language columns but metadata stays stale, that cast fails with a column
|
||||
mismatch.
|
||||
"""
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.io_utils import load_info, load_nested_dataset
|
||||
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT, language_feature_info
|
||||
|
||||
info_path = fixture_dataset_root / "meta" / "info.json"
|
||||
info = json.loads(info_path.read_text())
|
||||
info["features"] = {
|
||||
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
}
|
||||
info_path.write_text(json.dumps(info, indent=2))
|
||||
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{"role": "assistant", "content": "do X", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
Executor._ensure_annotation_metadata_in_info(fixture_dataset_root)
|
||||
|
||||
synced = load_info(fixture_dataset_root)
|
||||
for key, feature in language_feature_info().items():
|
||||
assert synced["features"][key] == feature
|
||||
|
||||
hf_features = get_hf_features_from_features(synced["features"])
|
||||
dataset = load_nested_dataset(fixture_dataset_root / "data", features=hf_features)
|
||||
|
||||
assert LANGUAGE_PERSISTENT in dataset.column_names
|
||||
assert LANGUAGE_EVENTS in dataset.column_names
|
||||
assert len(dataset) == 24
|
||||
|
||||
|
||||
def test_speech_atom_shape_matches_plan_spec() -> None:
|
||||
atom = speech_atom(2.5, "I'm cleaning up!")
|
||||
assert atom["role"] == "assistant"
|
||||
assert atom["style"] is None
|
||||
assert atom["content"] is None
|
||||
assert atom["timestamp"] == 2.5
|
||||
assert isinstance(atom["tool_calls"], list)
|
||||
call = atom["tool_calls"][0]
|
||||
assert call["type"] == "function"
|
||||
assert call["function"]["name"] == "say"
|
||||
assert call["function"]["arguments"]["text"] == "I'm cleaning up!"
|
||||
61
tests/fixtures/dataset_factories.py
vendored
61
tests/fixtures/dataset_factories.py
vendored
@@ -552,3 +552,64 @@ def lerobot_dataset_factory(
|
||||
@pytest.fixture(scope="session")
|
||||
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
|
||||
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
|
||||
|
||||
|
||||
def build_annotation_dataset(
|
||||
root: Path,
|
||||
episode_specs: list[tuple[int, int, str]],
|
||||
*,
|
||||
fps: int = 10,
|
||||
) -> Path:
|
||||
"""Build a minimal LeRobot-shaped dataset on disk for annotation tests.
|
||||
|
||||
``episode_specs`` is a list of ``(episode_index, num_frames, task_text)``.
|
||||
Each episode is written to its own
|
||||
``data/chunk-000/file-{ep:03d}.parquet`` so the writer's per-shard
|
||||
rewrite path is exercised. The dataset carries the minimum
|
||||
``meta/tasks.parquet`` + ``meta/info.json`` the reader / executor need;
|
||||
it has no videos, so the modules fall back to text-only prompts.
|
||||
|
||||
Shared by the annotation-pipeline pytest fixtures (``tests/annotations/
|
||||
conftest.py``) and the opt-in E2E smoke run so the fixture shape lives
|
||||
in exactly one place.
|
||||
"""
|
||||
from lerobot.datasets.io_utils import write_tasks
|
||||
from lerobot.utils.io_utils import write_json
|
||||
|
||||
data_dir = root / "data" / "chunk-000"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tasks: dict[int, str] = {}
|
||||
for episode_index, num_frames, task_text in episode_specs:
|
||||
if task_text not in tasks.values():
|
||||
tasks[len(tasks)] = task_text
|
||||
task_index = next(k for k, v in tasks.items() if v == task_text)
|
||||
frame = pd.DataFrame(
|
||||
{
|
||||
"episode_index": [episode_index] * num_frames,
|
||||
"frame_index": list(range(num_frames)),
|
||||
"timestamp": [round(i / fps, 6) for i in range(num_frames)],
|
||||
"task_index": [task_index] * num_frames,
|
||||
"subtask_index": [0] * num_frames, # legacy column the writer must drop
|
||||
}
|
||||
)
|
||||
frame.to_parquet(data_dir / f"file-{episode_index:03d}.parquet", index=False)
|
||||
|
||||
# Canonical tasks frame: indexed by task string with a ``task_index``
|
||||
# column, matching what ``lerobot.datasets.io_utils.load_tasks`` expects.
|
||||
tasks_df = pd.DataFrame(
|
||||
{"task_index": list(tasks.keys())},
|
||||
index=pd.Index(list(tasks.values()), name="task"),
|
||||
)
|
||||
write_tasks(tasks_df, root)
|
||||
|
||||
write_json(
|
||||
{
|
||||
"codebase_version": "v3.1",
|
||||
"fps": fps,
|
||||
"features": {},
|
||||
"total_episodes": len(episode_specs),
|
||||
},
|
||||
root / "meta" / "info.json",
|
||||
)
|
||||
return root
|
||||
|
||||
51
tests/scripts/test_lerobot_annotate.py
Normal file
51
tests/scripts/test_lerobot_annotate.py
Normal file
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
||||
def test_push_to_hub_tags_uploaded_dataset_revision(tmp_path, monkeypatch):
|
||||
from lerobot.scripts.lerobot_annotate import _push_to_hub
|
||||
|
||||
root = tmp_path / "dataset"
|
||||
(root / "meta").mkdir(parents=True)
|
||||
(root / "meta" / "info.json").write_text(json.dumps({"codebase_version": "v3.0"}))
|
||||
|
||||
calls = {}
|
||||
|
||||
class FakeHfApi:
|
||||
def create_repo(self, **kwargs):
|
||||
calls["create_repo"] = kwargs
|
||||
|
||||
def upload_folder(self, **kwargs):
|
||||
calls["upload_folder"] = kwargs
|
||||
return SimpleNamespace(oid="abc123")
|
||||
|
||||
def create_tag(self, **kwargs):
|
||||
calls["create_tag"] = kwargs
|
||||
|
||||
monkeypatch.setattr("huggingface_hub.HfApi", FakeHfApi)
|
||||
|
||||
cfg = SimpleNamespace(
|
||||
repo_id="source/dataset",
|
||||
dest_repo_id="annotated/dataset",
|
||||
push_private=True,
|
||||
push_commit_message=None,
|
||||
)
|
||||
|
||||
_push_to_hub(root, cfg)
|
||||
|
||||
assert calls["create_repo"] == {
|
||||
"repo_id": "annotated/dataset",
|
||||
"repo_type": "dataset",
|
||||
"private": True,
|
||||
"exist_ok": True,
|
||||
}
|
||||
assert calls["upload_folder"]["repo_id"] == "annotated/dataset"
|
||||
assert calls["create_tag"] == {
|
||||
"repo_id": "annotated/dataset",
|
||||
"tag": "v3.0",
|
||||
"repo_type": "dataset",
|
||||
"exist_ok": True,
|
||||
"revision": "abc123",
|
||||
}
|
||||
Reference in New Issue
Block a user