mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 11:51:25 +00:00
Compare commits
106 Commits
docs/compl
...
feat/langu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
98a519e7f2 | ||
|
|
5dbf0fac5f | ||
|
|
2bfaf44db2 | ||
|
|
1e7c0d6aa1 | ||
|
|
920c6ef5a2 | ||
|
|
c37b1fc7d0 | ||
|
|
9020635b14 | ||
|
|
471b2b1b1d | ||
|
|
a15e16c072 | ||
|
|
336af85c09 | ||
|
|
54221ceea2 | ||
|
|
369ab17110 | ||
|
|
86a7edc590 | ||
|
|
a0233f53f4 | ||
|
|
2ea0da2d9f | ||
|
|
134a707c7a | ||
|
|
ce47075d6b | ||
|
|
26013da699 | ||
|
|
f72b28738a | ||
|
|
1bd53cc7da | ||
|
|
7128bb1769 | ||
|
|
31e0c15e55 | ||
|
|
c5676ef1b3 | ||
|
|
9dfc9084e1 | ||
|
|
fd18beb3a1 | ||
|
|
965d42825f | ||
|
|
1238a0cd47 | ||
|
|
53c7641885 | ||
|
|
088c8371df | ||
|
|
3a52a18b0e | ||
|
|
dad2cf1178 | ||
|
|
bce5387e04 | ||
|
|
85576acc29 | ||
|
|
e7e5fca5de | ||
|
|
beb22afd81 | ||
|
|
d55b581ca1 | ||
|
|
24d2ffe3c6 | ||
|
|
789f29aa56 | ||
|
|
a356b12c41 | ||
|
|
e8327b8e62 | ||
|
|
c450298147 | ||
|
|
5c30b14929 | ||
|
|
8fa8323c91 | ||
|
|
73740ecf4b | ||
|
|
1b81e49214 | ||
|
|
d813c75b76 | ||
|
|
3434d2ef22 | ||
|
|
b71e10da6b | ||
|
|
0f6e3230df | ||
|
|
2f2e42c4aa | ||
|
|
5ee0104739 | ||
|
|
e064cfcb04 | ||
|
|
b3d9494831 | ||
|
|
1217fdb6f0 | ||
|
|
d0388e1142 | ||
|
|
524aa59faa | ||
|
|
27f7829b09 | ||
|
|
7f8bf108e8 | ||
|
|
855ff027f8 | ||
|
|
3b797bb118 | ||
|
|
aea04721ae | ||
|
|
ab5479129a | ||
|
|
e6d4ac6f02 | ||
|
|
5722d365c5 | ||
|
|
3d7e60cee4 | ||
|
|
7b767d4d60 | ||
|
|
f1e3ab7794 | ||
|
|
585341ba9f | ||
|
|
23ff346027 | ||
|
|
3c5cbe7af4 | ||
|
|
f2cbd97635 | ||
|
|
c06c8d594a | ||
|
|
cd495a3a9d | ||
|
|
c99ac45cd1 | ||
|
|
13aaafeae0 | ||
|
|
2129648bf4 | ||
|
|
f5cd3f6e4e | ||
|
|
ecf5766301 | ||
|
|
11597d4f71 | ||
|
|
8b9c598cf4 | ||
|
|
b325475b38 | ||
|
|
ef137ff86a | ||
|
|
c5df821a96 | ||
|
|
7ec3d7999c | ||
|
|
712d63abbd | ||
|
|
6653999983 | ||
|
|
4bdbedc9a0 | ||
|
|
e240305e8e | ||
|
|
ccd189b264 | ||
|
|
ef1242bbd4 | ||
|
|
ebf4a04d41 | ||
|
|
4419b4ef1b | ||
|
|
ff06ca82d2 | ||
|
|
fcb01e73eb | ||
|
|
268f8d1f53 | ||
|
|
663fff0ae2 | ||
|
|
9d6af804bf | ||
|
|
f763f85213 | ||
|
|
e3e9374e2c | ||
|
|
c1a0c601e2 | ||
|
|
1ca38d9748 | ||
|
|
5a6aa64570 | ||
|
|
0b06790da0 | ||
|
|
b43dc39ba4 | ||
|
|
2b71221194 | ||
|
|
8833d735a1 |
6
Makefile
6
Makefile
@@ -178,3 +178,9 @@ test-smolvla-ete-eval:
|
|||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1
|
--eval.batch_size=1
|
||||||
|
|
||||||
|
# E2E annotation pipeline smoke test against a tiny in-memory fixture
|
||||||
|
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
|
||||||
|
# backend, so it does not require a real model checkpoint or GPU.
|
||||||
|
annotation-e2e:
|
||||||
|
uv run python -m tests.annotations.run_e2e_smoke
|
||||||
|
|||||||
@@ -43,6 +43,8 @@
|
|||||||
title: Language Columns and Recipes
|
title: Language Columns and Recipes
|
||||||
- local: tools
|
- local: tools
|
||||||
title: Tools
|
title: Tools
|
||||||
|
- local: annotation_pipeline
|
||||||
|
title: Annotation Pipeline
|
||||||
- local: video_encoding_parameters
|
- local: video_encoding_parameters
|
||||||
title: Video encoding parameters
|
title: Video encoding parameters
|
||||||
- local: streaming_video_encoding
|
- local: streaming_video_encoding
|
||||||
|
|||||||
198
docs/source/annotation_pipeline.mdx
Normal file
198
docs/source/annotation_pipeline.mdx
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
# Annotation Pipeline
|
||||||
|
|
||||||
|
`lerobot-annotate` populates the two language columns introduced by the
|
||||||
|
[Language Columns and Recipes](./language_and_recipes) page —
|
||||||
|
`language_persistent` and `language_events` — directly into
|
||||||
|
`data/chunk-*/file-*.parquet`.
|
||||||
|
|
||||||
|
## What the pipeline produces
|
||||||
|
|
||||||
|
A vocabulary-discovery phase derives a small canonical wording, then three
|
||||||
|
modules write into a per-episode staging tree, then a single writer
|
||||||
|
rewrites the data shards in place:
|
||||||
|
|
||||||
|
| Style / atom | Column | Module |
|
||||||
|
| ------------------------------------------- | --------------------- | -------------- |
|
||||||
|
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
|
||||||
|
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
|
||||||
|
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
|
||||||
|
| `task_aug` (rephrasings of canonical task) | `language_persistent` | `plan` |
|
||||||
|
| `interjection` | `language_events` | `interjections`|
|
||||||
|
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections`|
|
||||||
|
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
|
||||||
|
|
||||||
|
The `plan` module is constrained to a **canonical vocabulary** discovered
|
||||||
|
once per dataset by the `vocabulary` module (phase 0). It watches a few
|
||||||
|
sample episode videos (`--vocabulary.sample_episodes`, default `3`) and
|
||||||
|
asks the VLM to derive a small set of imperative subtask labels and
|
||||||
|
first-person memory milestones that recur across the demos. The VLM
|
||||||
|
picks the right number of entries itself based on what it sees in the
|
||||||
|
clips — short pick-and-place demos get ~6 subtask labels, longer
|
||||||
|
multi-step recipes get more. The result lands at
|
||||||
|
`meta/canonical_vocabulary.json` (human-readable / hand-editable) and
|
||||||
|
is reused on every subsequent run. The `plan` module then constrains
|
||||||
|
both subtask + memory generation to those exact strings — the
|
||||||
|
downstream low-level policy sees a small, repeatable target
|
||||||
|
distribution instead of thousands of LLM paraphrases. Disable with
|
||||||
|
`--vocabulary.enabled=False` to fall back to free-form generation.
|
||||||
|
|
||||||
|
The writer does **not** add a `tools` column to the parquet — the tool
|
||||||
|
catalog lives at `meta/info.json["tools"]` instead (see
|
||||||
|
[Tools](./tools)). After every annotation run the pipeline ensures the
|
||||||
|
canonical `say` schema is present in that list, preserving any tools the
|
||||||
|
user pre-declared.
|
||||||
|
|
||||||
|
If you want to declare additional tools for a dataset before annotation
|
||||||
|
runs, edit `meta/info.json["tools"]` directly — the pipeline preserves
|
||||||
|
anything already there. Implementations of those tools live under
|
||||||
|
`src/lerobot/tools/`; one file per tool, registered via
|
||||||
|
`TOOL_REGISTRY`. See the [Tools](./tools) doc for the authoring guide.
|
||||||
|
|
||||||
|
## Running locally
|
||||||
|
|
||||||
|
Install the extra and invoke the console script. Episode-level
|
||||||
|
concurrency comes from `--executor.episode_parallelism` (default 16);
|
||||||
|
that is the only knob the in-process executor exposes.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv sync --extra annotations
|
||||||
|
uv run lerobot-annotate \
|
||||||
|
--root=/path/to/dataset \
|
||||||
|
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
```
|
||||||
|
|
||||||
|
The pipeline attaches actual camera footage to every `plan` /
|
||||||
|
`interjections` / `vqa` prompt by default, decoded from the dataset's
|
||||||
|
first `observation.images.*` stream. Override with
|
||||||
|
`--vlm.camera_key=observation.images.<name>` to pin a specific
|
||||||
|
viewpoint. Datasets with no video tracks fall back to text-only prompts
|
||||||
|
automatically.
|
||||||
|
|
||||||
|
**The `plan` module sees the whole episode as one video block.** Subtask
|
||||||
|
decomposition gets a `{"type":"video", "video":[<frames>]}` block
|
||||||
|
covering the entire demonstration; Qwen-VL pools temporally on its own
|
||||||
|
and decides where to cut. There is no keyframe stride or count knob —
|
||||||
|
`--plan.max_video_frames` (default 128) only caps the frames packed
|
||||||
|
into the video block as a model-capacity bound. The `interjections`
|
||||||
|
module attaches a short window of frames straddling the interjection
|
||||||
|
timestamp. The `vqa` module grounds each VQA pair on a single frame —
|
||||||
|
its `--vqa.K` knob sets how many consecutive frames each emission tick
|
||||||
|
anchors, and every anchored frame gets its own VQA pair on that one
|
||||||
|
frame (there is no per-pair frame window).
|
||||||
|
|
||||||
|
## Running on Hugging Face Jobs
|
||||||
|
|
||||||
|
Distributed annotation is delegated to
|
||||||
|
[Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs). The repo
|
||||||
|
ships a launcher script you copy and edit for your dataset:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||||
|
```
|
||||||
|
|
||||||
|
[`examples/annotations/run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||||
|
spawns one `h200x2` job that:
|
||||||
|
|
||||||
|
1. installs the branch under test plus the annotation extras,
|
||||||
|
2. boots two vllm servers (one per GPU) for the chosen model,
|
||||||
|
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
|
||||||
|
via `lerobot-annotate`,
|
||||||
|
4. uploads the annotated dataset to `--dest_repo_id` (or back to `--repo_id` when no destination is set) when `--push_to_hub=true`.
|
||||||
|
|
||||||
|
To target a different dataset, model, or hub repo, edit the `CMD` block
|
||||||
|
inside the script — every flag in there maps directly onto a CLI flag of
|
||||||
|
`lerobot-annotate` (see `lerobot-annotate --help` for the full list).
|
||||||
|
|
||||||
|
## Style-to-recipe consumer mapping
|
||||||
|
|
||||||
|
The pipeline's outputs are designed to be consumed by recipes (see
|
||||||
|
[Language Columns and Recipes](./language_and_recipes)) — typically:
|
||||||
|
|
||||||
|
- Low-level / high-level / memory-update branches consume
|
||||||
|
`subtask`/`plan`/`memory` from `language_persistent`.
|
||||||
|
- An interjection-response branch consumes `interjection` events plus
|
||||||
|
the paired speech atom (merged into one assistant target turn via
|
||||||
|
`tool_calls_from`) and the same-timestamp `plan` refresh.
|
||||||
|
- A VQA branch consumes the `(vqa, user)` and `(vqa, assistant)` pairs
|
||||||
|
from `language_events`.
|
||||||
|
|
||||||
|
## Why the design splits state from events
|
||||||
|
|
||||||
|
Two things drive the scope:
|
||||||
|
|
||||||
|
1. **Persistent state vs exact-event split.** Persistent rows
|
||||||
|
(`subtask`, `plan`, `memory`) broadcast per episode and answer "what
|
||||||
|
state is in force at this frame?". Event rows (`interjection`, `vqa`,
|
||||||
|
speech) only appear on the exact frame whose timestamp matches the
|
||||||
|
emission. The pipeline writes timestamps taken straight from the
|
||||||
|
source parquet — no floating-point recomputation.
|
||||||
|
2. **One Qwen-VL pass.** All three modules share a single VLM client
|
||||||
|
(vLLM if available, transformers fallback) so the cost is one model
|
||||||
|
load per dataset, not three.
|
||||||
|
|
||||||
|
## Module independence and staged reruns
|
||||||
|
|
||||||
|
Each module writes its raw output to
|
||||||
|
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. That makes
|
||||||
|
prompt iteration cheap — re-running one module overwrites only its own
|
||||||
|
JSONL file before the writer composes the final parquet. Modules can be
|
||||||
|
disabled via `--plan.enabled=false` (and likewise `--interjections.enabled`
|
||||||
|
/ `--vqa.enabled`) to
|
||||||
|
test them in isolation.
|
||||||
|
|
||||||
|
## Validation/report checks before final write
|
||||||
|
|
||||||
|
Before the writer runs, `StagingValidator` checks:
|
||||||
|
|
||||||
|
- exact frame-timestamp alignment for every event row;
|
||||||
|
- no orphan speech / interjection pairs;
|
||||||
|
- `plan` is refreshed at every interjection timestamp;
|
||||||
|
- `memory` rows fall on subtask boundaries (warning, not error);
|
||||||
|
- VQA assistant `content` parses as JSON in one of the
|
||||||
|
bbox / keypoint / count / attribute / spatial shapes;
|
||||||
|
- every row routes to the column dictated by `column_for_style(style)`.
|
||||||
|
|
||||||
|
Errors abort the writer (`--skip_validation=true` overrides for debugging).
|
||||||
|
|
||||||
|
## Paper inspirations per module
|
||||||
|
|
||||||
|
- **`plan` module — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
|
||||||
|
atom granularity ("pick up one piece of lettuce", "place bowl to box");
|
||||||
|
Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07)) "how, not
|
||||||
|
what" detail.
|
||||||
|
- **`plan` module — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596))
|
||||||
|
compression directive: keep only minimal relevant information; functional
|
||||||
|
outcomes preserved, specific attributes dropped.
|
||||||
|
- **`interjections` module.** Hi Robot scenario taxonomy: negative task,
|
||||||
|
situated correction, specific constraint, preference. Speech is a
|
||||||
|
tool-call-only atom (`tool_calls=[{type:function, function:{name:"say",
|
||||||
|
arguments:{text:...}}}]`).
|
||||||
|
- **`vqa` module.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693))
|
||||||
|
grounded features (bounding boxes in pixel `[x_min, y_min, x_max, y_max]`,
|
||||||
|
keypoints) and Steerable VLA Policies ([Zhao 2025](https://arxiv.org/abs/2509.07626))
|
||||||
|
multi-abstraction grounding. Pi0.7 also grounds answers across
|
||||||
|
multiple abstraction levels.
|
||||||
|
|
||||||
|
Future maintainers should adjust the prompt templates in
|
||||||
|
`src/lerobot/annotations/steerable_pipeline/prompts/` against these
|
||||||
|
references rather than rewriting from scratch.
|
||||||
|
|
||||||
|
## Compute and list-size estimates
|
||||||
|
|
||||||
|
Per episode, the pipeline issues O(`max_steps`) `plan`-module calls,
|
||||||
|
O(`max_interjections_per_episode`) `interjections`-module calls, and
|
||||||
|
O(`vqa_emission_hz × episode_seconds`) `vqa`-module calls. With defaults
|
||||||
|
(8 subtasks, 1 interjection, 1 Hz × 3 pairs) and 30-second episodes, that
|
||||||
|
is ~50 VLM calls per episode. `language_persistent` per episode is ~10s of
|
||||||
|
KB at most (parquet dictionary-encodes one entry per episode);
|
||||||
|
`language_events` is empty on most frames and is bounded by the number of
|
||||||
|
emissions, not `num_frames × num_emissions`.
|
||||||
|
|
||||||
|
## Reproducibility via seed and prompt hashes
|
||||||
|
|
||||||
|
`--seed` (default 1729) feeds the per-episode RNGs that select interjection
|
||||||
|
timestamps and VQA question types. Combined with the deterministic prompt
|
||||||
|
templates checked into `prompts/`, two runs at the same seed against the
|
||||||
|
same dataset and the same model checkpoint produce byte-identical staging
|
||||||
|
artifacts. Prompt edits are recorded by file hash; future tooling can pin
|
||||||
|
expected `(seed, prompt_hash)` pairs into the dataset card.
|
||||||
78
examples/annotations/run_hf_job.py
Normal file
78
examples/annotations/run_hf_job.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
#!/usr/bin/env python
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6 MoE).

Spawns one ``h200x2`` job that:

1. installs this branch of ``lerobot`` plus the annotation extras,
2. boots two vllm servers (one per GPU) with Qwen3.6-35B-A3B-FP8,
3. runs the plan / interjections / vqa modules across the dataset
   in free-form mode (each episode generates its own subtasks +
   memory),
4. uploads the annotated dataset to ``--dest_repo_id`` (when set)
   or back to ``--repo_id``.

Usage:

    HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py

Adjust ``CMD`` below to point at your own dataset / target hub repo.
"""

import os

from huggingface_hub import get_token, run_job

# Shell command executed inside the job container via ``bash -c``. It is a
# single string built from adjacent literals, so every fragment except the
# last must end with a trailing space (or ``&& ``) to keep the flags separated.
# NOTE(review): ``{port}`` inside ``--vlm.serve_command`` is passed through
# literally here; presumably ``lerobot-annotate`` substitutes it per server —
# confirm against the CLI implementation.
CMD = (
    "apt-get update -qq && apt-get install -y -qq git ffmpeg && "
    "pip install --no-deps "
    "'lerobot @ git+https://github.com/huggingface/lerobot.git@feat/language-annotation-pipeline' && "
    "pip install --upgrade-strategy only-if-needed "
    "datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect "
    "openai && "
    "export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
    "export VLLM_VIDEO_BACKEND=pyav && "
    "lerobot-annotate "
    "--repo_id=pepijn223/robocasa_smoke_2atomic_v3 "
    "--dest_repo_id=pepijn223/robocasa_smoke_2atomic_v3_ann "
    "--push_to_hub=true "
    "--vlm.backend=openai "
    "--vlm.model_id=Qwen/Qwen3.6-35B-A3B-FP8 "
    "--vlm.parallel_servers=2 "
    "--vlm.num_gpus=2 "
    '--vlm.serve_command="vllm serve Qwen/Qwen3.6-35B-A3B-FP8 '
    "--tensor-parallel-size 1 --max-model-len 32768 "
    '--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
    "--vlm.serve_ready_timeout_s=1800 "
    "--vlm.client_concurrency=128 "
    "--vlm.max_new_tokens=512 "
    "--vlm.temperature=0.7 "
    "--executor.episode_parallelism=16 "
    "--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
    "--vlm.camera_key=observation.images.robot0_agentview_right "
    # Phase 1 — plan module (subtasks + plan + memory + task_aug).
    "--plan.frames_per_second=1.0 "
    "--plan.use_video_url=true "
    "--plan.use_video_url_fps=1.0 "
    "--plan.derive_task_from_video=always "
    "--plan.task_aug_axes.enabled=true "
    "--plan.action_records.enabled=true "
    # Phase 2 — interjections + speech.
    "--interjections.max_interjections_per_episode=6 "
    # Phase 4 — general VQA.
    "--vqa.K=1 "
    "--vqa.vqa_emission_hz=1.0"
)


def main() -> None:
    """Submit the annotation job and print its URL and ID.

    The token is read from the ``HF_TOKEN`` environment variable first and
    falls back to ``huggingface_hub.get_token()``.

    Raises:
        RuntimeError: If neither source yields a token.
    """
    token = os.environ.get("HF_TOKEN") or get_token()
    if not token:
        raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")

    job = run_job(
        image="vllm/vllm-openai:latest",
        command=["bash", "-c", CMD],
        flavor="h200x2",
        # Forwarded as a job secret so the container can authenticate to the Hub.
        secrets={"HF_TOKEN": token},
        timeout="2h",
    )
    print(f"Job URL: {job.url}")
    print(f"Job ID:  {job.id}")


# Guarded so that importing this example never launches a job by accident.
if __name__ == "__main__":
    main()
|
||||||
@@ -219,6 +219,18 @@ hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.
|
|||||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||||
|
|
||||||
|
# Annotation pipeline (lerobot-annotate). vllm is the preferred backend
|
||||||
|
# on Linux, with a transformers fallback elsewhere; openai is the default
|
||||||
|
# backend and talks to any OpenAI-compatible server (``vllm serve`` /
|
||||||
|
# ``transformers serve`` / hosted endpoints). Distributed execution is
|
||||||
|
# delegated to Hugging Face Jobs (see examples/annotations/run_hf_job.py).
|
||||||
|
annotations = [
|
||||||
|
"lerobot[dataset]",
|
||||||
|
"lerobot[transformers-dep]",
|
||||||
|
"openai>=1.40,<2.0",
|
||||||
|
"vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
|
||||||
|
]
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||||
@@ -309,6 +321,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
|||||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||||
|
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||||
|
|
||||||
# ---------------- Tool Configurations ----------------
|
# ---------------- Tool Configurations ----------------
|
||||||
@@ -327,7 +340,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
|||||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
lerobot = ["envs/*.json"]
|
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
@@ -388,7 +401,7 @@ exclude_dirs = [
|
|||||||
"benchmarks",
|
"benchmarks",
|
||||||
"src/lerobot/datasets/push_dataset_to_hub",
|
"src/lerobot/datasets/push_dataset_to_hub",
|
||||||
]
|
]
|
||||||
skips = ["B101", "B311", "B404", "B603", "B615"]
|
skips = ["B101", "B311", "B404", "B603", "B607", "B615"]
|
||||||
|
|
||||||
[tool.typos]
|
[tool.typos]
|
||||||
default.extend-ignore-re = [
|
default.extend-ignore-re = [
|
||||||
|
|||||||
15
src/lerobot/annotations/__init__.py
Normal file
15
src/lerobot/annotations/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
36
src/lerobot/annotations/steerable_pipeline/__init__.py
Normal file
36
src/lerobot/annotations/steerable_pipeline/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Steerable annotation pipeline producing ``language_persistent`` and
|
||||||
|
``language_events`` columns for LeRobot datasets.
|
||||||
|
|
||||||
|
The pipeline is decomposed into three independently runnable modules whose
|
||||||
|
outputs are staged per-episode before a final parquet rewrite:
|
||||||
|
|
||||||
|
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
|
||||||
|
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
|
||||||
|
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .config import AnnotationPipelineConfig
|
||||||
|
from .validator import StagingValidator, ValidationReport
|
||||||
|
from .writer import LanguageColumnsWriter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AnnotationPipelineConfig",
|
||||||
|
"LanguageColumnsWriter",
|
||||||
|
"StagingValidator",
|
||||||
|
"ValidationReport",
|
||||||
|
]
|
||||||
345
src/lerobot/annotations/steerable_pipeline/config.py
Normal file
345
src/lerobot/annotations/steerable_pipeline/config.py
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PlanConfig:
    """``plan`` module: plan + subtasks + memory + task augmentation.

    The ``plan`` module attaches the whole episode as one Qwen-VL video
    block; ``max_video_frames`` only caps the frames packed in (a
    model-capacity bound, not an annotation-logic knob).
    """

    # Master switch for the ``plan`` module.
    enabled: bool = True

    # Number of ``task_aug`` rephrasings emitted at ``t=0``. The renderer's
    # ``${task}`` binding rotates among them per ``sample_idx``. ``0`` disables.
    n_task_rephrasings: int = 10

    # When to derive the task from the video instead of using
    # ``record.episode_task``: ``off``, ``if_short`` (short / placeholder /
    # missing canonical task), or ``always``. The derived task replaces the
    # canonical one for every ``plan``-module prompt; ``meta/tasks.parquet``
    # is never modified.
    derive_task_from_video: str = "if_short"
    # NOTE(review): presumably the word-count threshold used by ``if_short``
    # to decide a task is "short" — confirm against the plan module.
    derive_task_min_words: int = 3

    # Frame sampling for the subtask-decomposition prompt.
    frames_per_second: float = 1.0
    max_video_frames: int = 128

    min_subtask_seconds: float = 1.5
    plan_max_steps: int = 8

    # When True (and backend supports it, e.g. ``openai``), the ``plan``
    # module sends a ``video_url`` block pointing at a per-episode mp4
    # subclip and lets the server sample frames at ``use_video_url_fps``.
    use_video_url: bool = False
    use_video_url_fps: float = 1.0

    # Structured per-subtask action records (Phase 1a + 1b, inspired by
    # EgoMimic's annotator form). For each generated subtask span, the
    # VLM extracts a typed record (verb / object / arm / grasp_type /
    # destination / mistake). A deterministic Python template renders
    # that record back to canonical subtask text — reducing the VLM's
    # "creative" surface to just the perception step. See
    # ``ActionRecordsConfig`` for details. Off by default (back-compat).
    # The lambda defers the name lookup: ``ActionRecordsConfig`` is defined
    # later in this module, so it does not exist yet when this class body runs.
    action_records: ActionRecordsConfig = field(default_factory=lambda: ActionRecordsConfig())

    # Structured 5-axis augmentation taxonomy for the t=0 task variants
    # (replaces the free-form ``n_task_rephrasings`` flow when enabled).
    # Mirrors EgoMimic's ``augment_prompt.txt`` taxonomy: instead of N
    # free-form rephrasings, the VLM produces variants along named
    # axes (synonym / omit_arm / omit_orientation / omit_grasp_method /
    # combined). Off by default (back-compat).
    # Same deferred lookup as above: ``TaskAugAxesConfig`` is defined later.
    task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig())
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ActionRecordsConfig:
    """Structured per-subtask action record extraction.

    When ``enabled=True``, after the existing subtask-span generation in
    ``plan_subtasks_memory.py``, the module makes one extra VLM call per
    subtask to extract a typed record::

        {
            "verb": "pick" | "place" | "press" | ...,   # closed vocabulary
            "object": "<canonical_object_name>",
            "arm": "left" | "right" | "both" | null,
            "grasp_type": "pinch" | "wrap" | "hook" | ... | null,
            "destination": "<canonical_destination>" | null,
            "mistake": "<short text>" | null,
        }

    A deterministic Python template then renders the record back to
    canonical subtask text (e.g. ``pick blue cube with left arm using
    pinch grip``). When ``replace_subtask_text=True`` (default), the
    rendered text REPLACES the VLM's free-form subtask text — eliminating
    cross-episode phrasing drift. When ``emit_record_row=True``
    (default), the structured record is also emitted as a row with
    ``style="action_record"`` so downstream consumers can train on the
    typed schema directly.

    Cost: one extra VLM call per subtask. For an 8-subtask episode this
    adds 8 VLM calls to the plan module — still cheap relative to the
    action-expert training cost, but worth knowing.
    """

    # Master switch; off by default.
    enabled: bool = False

    # When True, replace the VLM-generated subtask text with the
    # deterministic template's rendering of the structured record.
    # Strongly recommended — it's the whole point of the structured
    # intermediate. Set False to keep both representations side by side.
    replace_subtask_text: bool = True

    # When True, emit a separate row with ``style="action_record"`` and
    # ``content=json.dumps(record)`` at the subtask's start timestamp.
    # Lets downstream training consume the typed schema directly (e.g.
    # auxiliary supervision on verb/arm/grasp classification heads).
    emit_record_row: bool = True

    # Frame sampling for the per-subtask VLM call (similar to the
    # interjection module's window). Anchored to the subtask span.
    frames_per_subtask: int = 4

    # Closed verb vocabulary. The prompt instructs the VLM to pick
    # exactly one. Override per-dataset (e.g. ``["pick", "place", "open",
    # "close"]`` for door-only manipulation) for tighter constraint.
    # A tuple (immutable) so it is safe as a plain dataclass default.
    verb_vocabulary: tuple[str, ...] = (
        "pick", "place", "push", "pull", "open", "close", "turn",
        "press", "lift", "insert", "pour", "move", "reach", "grasp",
        "release", "wipe", "dump",
    )

    # Closed grasp-type vocabulary. ``null`` is always allowed (no
    # contact / unclear). Adjust per-hardware (e.g. drop ``hook`` /
    # ``key`` for parallel-jaw grippers).
    grasp_vocabulary: tuple[str, ...] = (
        "pinch", "wrap", "hook", "key", "lateral",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TaskAugAxesConfig:
    """Per-axis variant budget for structured ``t=0`` task augmentation.

    Setting ``enabled=True`` swaps the free-form ``n_task_rephrasings``
    flow for a structured prompt that asks for variants along five named
    axes (mirroring EgoMimic's ``augment_prompt.txt``):

    * ``synonym_paraphrase`` — reworded task, all information kept.
    * ``omit_arm`` — the left/right/both arm specification is dropped.
    * ``omit_orientation`` — orientation cues (upright, sideways, ...)
      are dropped.
    * ``omit_grasp_method`` — the grip / grasp method is dropped.
    * ``combined_omissions`` — two of the omissions above at once.

    The defaults (3+3+2+2+2 = 12 variants per task) match EgoMimic. When
    the source task has nothing to omit on an axis (e.g. ``omit_arm`` for
    a task that never mentions an arm) the prompt tells the VLM to return
    fewer entries instead of padding.

    Every variant becomes a ``task_aug`` row at ``t=0`` — the same style
    the free-form variants use — so downstream pipeline stages and the
    training recipe stay unaware of the taxonomy.
    """

    enabled: bool = False

    # Requested number of variants on each axis.
    synonym_paraphrase: int = 3
    omit_arm: int = 3
    omit_orientation: int = 2
    omit_grasp_method: int = 2
    combined_omissions: int = 2

    @property
    def total(self) -> int:
        """Upper bound on emitted variants: the per-axis requests added up."""
        per_axis = (
            self.synonym_paraphrase,
            self.omit_arm,
            self.omit_orientation,
            self.omit_grasp_method,
            self.combined_omissions,
        )
        return sum(per_axis)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class InterjectionsConfig:
    """Settings for the ``interjections`` module (interjections + paired speech)."""

    enabled: bool = True

    # An interjection is written as a paired ``(interjection, speech)``
    # event row; the ``plan`` module refreshes the ``plan`` at that same
    # timestamp.
    max_interjections_per_episode: int = 3
    interjection_min_t: float = 2.0

    # The interjection prompt carries a short window of frames centered on
    # the chosen timestamp, so the VLM observes the motion in progress
    # instead of one frozen frame.
    interjection_window_seconds: float = 2.0
    interjection_window_frames: int = 4
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VqaConfig:
    """Settings for the ``vqa`` module (general VQA)."""

    enabled: bool = True
    vqa_emission_hz: float = 1.0

    # Number of *consecutive* frames each emission tick anchors a VQA pair
    # to. The VLM grounds its answer (bbox / keypoint coordinates, count, …)
    # in the *first* anchored frame's image, so K>1 copies that same answer
    # onto later frames where the scene has already moved — stale labels.
    # The default of ``1`` puts a VQA pair on exactly its emission frame
    # with no temporal smear; raise it only to trade label precision for
    # more (noisier) VQA frames.
    K: int = 1

    question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VlmConfig:
    """Configuration of the shared Qwen-VL client."""

    # Backend selector: ``vllm``, ``transformers``, ``openai``, or ``stub``
    # (tests). The ``openai`` backend talks to a local OpenAI-compatible
    # server, which the CLI auto-spawns when ``auto_serve=True``.
    backend: str = "openai"
    model_id: str = "Qwen/Qwen3.6-35B-A3B-FP8"

    # Endpoint of the OpenAI-compatible server. Local servers accept the
    # ``EMPTY`` key.
    api_base: str = "http://localhost:8000/v1"
    api_key: str = "EMPTY"

    # With ``backend=openai`` and this flag on, the CLI probes ``api_base``
    # and starts a server when nothing answers (default command:
    # ``transformers serve``). Turn it off to fail fast against a remote
    # endpoint.
    auto_serve: bool = True
    serve_port: int = 8000
    # Replacement for the auto-serve command. ``{port}`` is filled in per
    # replica when ``parallel_servers > 1``.
    serve_command: str | None = None

    # Number of independent inference servers used for round-robin client
    # routing. Each one is pinned to a GPU through ``CUDA_VISIBLE_DEVICES``
    # and bound to ``serve_port + i``. ``num_gpus=0`` means one GPU per
    # replica.
    parallel_servers: int = 1
    num_gpus: int = 0
    client_concurrency: int = 16
    serve_ready_timeout_s: float = 600.0

    max_new_tokens: int = 512
    temperature: float = 0.2
    json_mode: bool = True
    batch_size: int = 4
    tensor_parallel_size: int = 1

    # Share of GPU memory vllm reserves for weights + KV cache.
    gpu_memory_utilization: float = 0.9
    # Context-length cap (None = model default). A 30B BF16 model on an
    # 80 GB H100 often needs <= 8192 to keep KV-cache headroom.
    max_model_len: int | None = None
    trust_remote_code: bool = False

    # Camera stream used for keyframe attachment. None selects the first
    # ``observation.images.*`` key the dataset declares.
    camera_key: str | None = None
    # Sent as ``extra_body.chat_template_kwargs`` on every chat call; use it
    # for model-specific flags such as ``{"enable_thinking": false}``.
    chat_template_kwargs: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ExecutorConfig:
    """Settings for the in-process executor.

    This config governs intra-process episode concurrency only;
    distributed execution is delegated to Hugging Face Jobs (see
    ``examples/annotation/run_hf_job.py``).
    """

    # Number of episodes handled concurrently inside each module phase.
    # Every in-flight episode issues 3-5 dependent VLM calls, which makes
    # this the primary knob for saturating ``parallel_servers`` and
    # ``client_concurrency``.
    episode_parallelism: int = 16
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AnnotationPipelineConfig:
    """Root configuration object for ``lerobot-annotate``.

    ``data/chunk-*/file-*.parquet`` is rewritten in place by the writer, so
    distinct revisions of one dataset are kept as separate copies.
    """

    # Hub dataset id. It is the download source when ``root`` is unset, and
    # the upload destination when ``push_to_hub`` is on and ``dest_repo_id``
    # is unset.
    repo_id: str | None = None

    # Optional separate Hub dataset id that receives the annotated result.
    # Unset: ``push_to_hub`` uploads back to ``repo_id`` (annotate in
    # place). Set: the source ``repo_id`` is left untouched.
    dest_repo_id: str | None = None

    root: Path | None = None

    # Falls back to ``<root>/.annotate_staging/`` when unset.
    staging_dir: Path | None = None

    seed: int = 1729

    # Per-module settings.
    plan: PlanConfig = field(default_factory=PlanConfig)
    interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
    vqa: VqaConfig = field(default_factory=VqaConfig)

    # Shared client and executor settings.
    vlm: VlmConfig = field(default_factory=VlmConfig)
    executor: ExecutorConfig = field(default_factory=ExecutorConfig)

    skip_validation: bool = False
    only_episodes: tuple[int, ...] | None = None

    # Keyframe decode backend. Unset means the ffmpeg CLI, which decodes
    # AV1 and runs every decode as an isolated child process — crash-safe
    # and safe under the executor's concurrent decode (torchcodec is not
    # thread-safe and SIGSEGVs there). ``"torchcodec"`` or ``"pyav"`` pins
    # an in-process decoder for builds known to be thread-safe.
    video_backend: str | None = None

    # Upload the annotated dataset to the Hugging Face Hub when True: to
    # ``dest_repo_id`` if set, else back to ``repo_id``. At least one of
    # the two has to be set for the upload to happen.
    push_to_hub: bool = False
    push_private: bool = False
    push_commit_message: str | None = None

    def resolved_staging_dir(self, root: Path) -> Path:
        """Return ``staging_dir`` if configured, else ``root / ".annotate_staging"``."""
        if self.staging_dir is None:
            return root / ".annotate_staging"
        return self.staging_dir
|
||||||
261
src/lerobot/annotations/steerable_pipeline/executor.py
Normal file
261
src/lerobot/annotations/steerable_pipeline/executor.py
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""In-process executor that runs the annotation phases.
|
||||||
|
|
||||||
|
The executor plans **seven phases** in the dependency order from the plan:
|
||||||
|
|
||||||
|
phase 0: vocabulary discovery — derive a small canonical vocabulary
|
||||||
|
from the first few sample-episode videos (subtask labels +
|
||||||
|
memory milestones) and persist it next to the dataset; the
|
||||||
|
``plan`` module then constrains every per-episode generation
|
||||||
|
to those strings, so the downstream policy sees a small,
|
||||||
|
repeatable conditioning distribution
|
||||||
|
phase 1: ``plan`` module (plan + subtasks + memory)
|
||||||
|
phase 2: ``interjections`` module (interjections + speech)
|
||||||
|
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
|
||||||
|
interjection timestamp produced by phase 2
|
||||||
|
phase 4: ``vqa`` module (VQA)
|
||||||
|
phase 5: validator
|
||||||
|
phase 6: writer
|
||||||
|
|
||||||
|
Phase 3 is why the ``plan`` module must be re-entered after the
|
||||||
|
``interjections`` module — to refresh ``plan`` rows at interjection
|
||||||
|
timestamps.
|
||||||
|
|
||||||
|
Distributed execution is provided by Hugging Face Jobs (see
|
||||||
|
``examples/annotations/run_hf_job.py``); the runner inside the job
|
||||||
|
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||||
|
Episode-level concurrency is controlled by
|
||||||
|
``ExecutorConfig.episode_parallelism``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .config import AnnotationPipelineConfig
|
||||||
|
from .reader import EpisodeRecord, iter_episodes
|
||||||
|
from .staging import EpisodeStaging
|
||||||
|
from .validator import StagingValidator
|
||||||
|
from .writer import LanguageColumnsWriter
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PhaseResult:
    """Per-phase outcome, aggregated over every episode in the run."""

    # Phase label, e.g. ``"plan"`` or ``"plan_update"``.
    name: str
    # How many episodes the phase ran on.
    episodes_processed: int
    # How many episodes the phase did not run on.
    episodes_skipped: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PipelineRunSummary:
    """What :meth:`Executor.run` hands back once the pipeline finishes."""

    # One entry per executed phase, in execution order.
    phases: list[PhaseResult]
    # Paths reported by the writer.
    written_paths: list[Path]
    # ValidationReport, kept Any to avoid import cycle
    validation_report: Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Executor:
    """Run the annotation pipeline over a dataset root in-process.

    :meth:`run` performs six steps in order: the ``plan`` module, the
    ``interjections`` module, the ``plan`` plan-update pass, the ``vqa``
    module, the validator, and the writer.

    Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
    (a thread pool); cluster-level concurrency comes from running this
    executor inside a Hugging Face Job. Tests construct the executor
    directly with stub modules.
    """

    config: AnnotationPipelineConfig
    plan: Any  # PlanSubtasksMemoryModule
    interjections: Any  # InterjectionsAndSpeechModule
    vqa: Any  # GeneralVqaModule
    writer: LanguageColumnsWriter
    validator: StagingValidator

    def run(self, root: Path) -> PipelineRunSummary:
        """Annotate every selected episode under ``root`` and write the result.

        Args:
            root: Dataset root directory.

        Returns:
            The per-phase results, the paths returned by the writer, and the
            validation report.

        Raises:
            ValueError: if no episode is found under ``root``.
            RuntimeError: if staging validation fails and
                ``config.skip_validation`` is False.
        """
        records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
        n = len(records)
        if n == 0:
            raise ValueError(f"No episodes found under {root}/data/")

        print(f"[annotate] {n} episodes total", flush=True)

        staging_dir = self.config.resolved_staging_dir(root)
        staging_dir.mkdir(parents=True, exist_ok=True)

        phases: list[PhaseResult] = []

        # Phase 1: ``plan`` module (plan + subtasks + memory)
        phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
        # Phase 2: ``interjections`` module (interjections + speech). It
        # reads the ``plan`` module's subtask rows from the same staging
        # tree to ground the interjection prompt in the correct local subtask.
        phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
        # Phase 3: ``plan`` plan-update pass at interjection timestamps.
        phases.append(self._run_plan_update_phase(records, staging_dir))
        # Phase 4: ``vqa`` module (VQA)
        phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))

        print("[annotate] running validator...", flush=True)
        report = self.validator.validate(records, staging_dir)
        if not report.ok and not self.config.skip_validation:
            raise RuntimeError(f"Staging validation failed: {report.summary()}")
        print(f"[annotate] validator: {report.summary()}", flush=True)

        print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
        written = self.writer.write_all(records, staging_dir, root)
        print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)

        # Keep meta/info.json aligned with the parquet schema we just wrote.
        # Idempotent and additive: existing user metadata is preserved.
        self._ensure_annotation_metadata_in_info(root)

        return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)

    @staticmethod
    def _tool_name(tool: Any) -> str | None:
        """Return ``tool["function"]["name"]``, or ``None`` for a malformed entry."""
        if not isinstance(tool, dict):
            return None
        function = tool.get("function")
        if not isinstance(function, dict):
            return None
        return function.get("name")

    @staticmethod
    def _ensure_annotation_metadata_in_info(root: Path) -> None:
        """Write language features and canonical tools to ``meta/info.json``.

        ``LanguageColumnsWriter`` adds ``language_persistent`` and
        ``language_events`` to parquet shards. The metadata must advertise
        those columns too, otherwise non-streaming ``LeRobotDataset`` loads
        cast against the old schema and fail on the extra parquet columns.

        Does nothing when ``meta/info.json`` is missing; a read failure is
        reported on stdout and the update is skipped. The file is rewritten
        only when the features or the tool list actually change.
        """
        from lerobot.datasets.io_utils import load_info, write_info  # noqa: PLC0415
        from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info  # noqa: PLC0415

        info_path = root / "meta" / "info.json"
        if not info_path.exists():
            return
        try:
            info = load_info(root)
        except Exception as exc:  # noqa: BLE001
            print(f"[annotate] could not read {info_path}: {exc}", flush=True)
            return

        changed = False

        merged_features = {**info.features, **language_feature_info()}
        if merged_features != info.features:
            info.features = merged_features
            changed = True

        existing = info.tools or []
        names = {Executor._tool_name(t) for t in existing}
        if SAY_TOOL_SCHEMA["function"]["name"] not in names:
            info.tools = [*existing, SAY_TOOL_SCHEMA]
            changed = True

        if changed:
            write_info(info, root)
        # Tool names go through ``_tool_name`` so that a malformed tool entry
        # already present in info.json cannot crash this log line.
        print(
            "[annotate] meta/info.json: "
            f"language_features={list(language_feature_info())}, "
            f"tools={[Executor._tool_name(t) for t in (info.tools or [])]}",
            flush=True,
        )

    def _run_module_phase(
        self,
        name: str,
        records: list[EpisodeRecord],
        staging_dir: Path,
        module: Any,
    ) -> PhaseResult:
        """Run ``module.run_episode`` on every record and report progress.

        A disabled module (``module.enabled`` falsy) is not run and all
        records count as skipped. Otherwise episodes run sequentially when
        the effective parallelism is 1, or on a thread pool of
        ``min(episode_parallelism, len(records))`` workers. An exception
        raised by ``run_episode`` propagates to the caller.
        """
        if not module.enabled:
            print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
            return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
        n = len(records)
        parallelism = max(1, min(self.config.executor.episode_parallelism, n))
        print(
            f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
            flush=True,
        )
        t0 = time.time()

        def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
            # Returns (submit order, episode index, elapsed seconds).
            i, record = idx_record
            ep_start = time.time()
            staging = EpisodeStaging(staging_dir, record.episode_index)
            module.run_episode(record, staging)
            return i, record.episode_index, time.time() - ep_start

        processed = 0
        if parallelism == 1:
            for i, record in enumerate(records, 1):
                _, ep_idx, elapsed = _do((i, record))
                processed += 1
                print(
                    f"[annotate]   {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
                    flush=True,
                )
        else:
            with ThreadPoolExecutor(max_workers=parallelism) as pool:
                futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
                # Completion order differs from submit order, so the log
                # carries both the running count and the submit position.
                for fut in as_completed(futures):
                    i, ep_idx, elapsed = fut.result()
                    processed += 1
                    print(
                        f"[annotate]   {name} episode {processed}/{n} "
                        f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
                        flush=True,
                    )
        total = time.time() - t0
        print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
        return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)

    def _run_plan_update_phase(  # noqa: PLR0915
        self, records: list[EpisodeRecord], staging_dir: Path
    ) -> PhaseResult:
        """Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.

        The ``plan`` module owns the prompt; the ``interjections`` module
        produced the timestamps. This phase therefore calls back into the
        ``plan`` module with the interjection timestamps so its existing
        prompt path is reused.

        The phase is skipped entirely when either module is disabled.
        Episodes are handled one after another in this phase.
        """
        if not self.plan.enabled or not self.interjections.enabled:
            return PhaseResult(
                name="plan_update", episodes_processed=0, episodes_skipped=len(records)
            )
        processed = 0
        for record in records:
            staging = EpisodeStaging(staging_dir, record.episode_index)
            interjection_rows = [
                row for row in staging.read("interjections") if row.get("style") == "interjection"
            ]
            interjection_times = [float(row["timestamp"]) for row in interjection_rows]
            interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
            if interjection_times:
                self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
                processed += 1
        # Episodes without any interjections are skipped (no plan refresh
        # needed); count them so the summary's processed+skipped == total.
        return PhaseResult(
            name="plan_update",
            episodes_processed=processed,
            episodes_skipped=len(records) - processed,
        )
|
||||||
494
src/lerobot/annotations/steerable_pipeline/frames.py
Normal file
494
src/lerobot/annotations/steerable_pipeline/frames.py
Normal file
@@ -0,0 +1,494 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Keyframe extraction for the annotation pipeline.
|
||||||
|
|
||||||
|
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||||
|
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||||
|
visual content. The pipeline shares one provider across modules and one
|
||||||
|
episode at a time, with a small per-episode cache so multiple modules
|
||||||
|
querying the same timestamp pay decode cost once.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.datasets.video_utils import decode_video_frames
|
||||||
|
|
||||||
|
from .reader import EpisodeRecord
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FrameProvider(Protocol):
    """Structural interface for decoding camera frames at episode-relative times."""

    @property
    def camera_keys(self) -> list[str]:
        """Every ``observation.images.*`` feature key the provider is able to decode."""
        ...

    def frames_at(
        self,
        record: EpisodeRecord,
        timestamps: list[float],
        camera_key: str | None = None,
    ) -> list[Any]:
        """Decode one frame per timestamp from ``camera_key`` (or the default camera).

        Frames are ``torch.Tensor`` (``C, H, W`` uint8), i.e. the shape
        returned by :func:`lerobot.datasets.video_utils.decode_video_frames`;
        conversion to PIL is deferred to :func:`to_image_blocks` at the
        VLM-message boundary.

        An unavailable camera yields an empty list. Passing
        ``camera_key=None`` selects the provider's default camera, which
        keeps existing single-camera callers (the ``plan`` and
        ``interjections`` modules) working unchanged.
        """
        ...

    def video_for_episode(
        self,
        record: EpisodeRecord,
        max_frames: int,
        camera_key: str | None = None,
    ) -> list[Any]:
        """Decode at most ``max_frames`` frames spanning the whole episode.

        Timestamps are sampled uniformly over the episode duration. Frames
        are ``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block`
        packs them into a single ``{"type":"video", "video":<list>}`` block
        for a Qwen-VL-compatible model that does its own temporal pooling.
        With no camera available the result is an empty list.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _NullProvider:
    """Provider that never decodes anything.

    Stands in when the dataset has no video keys, and in tests.
    """

    @property
    def camera_keys(self) -> list[str]:
        """No cameras are exposed."""
        return list()

    def frames_at(
        self,
        record: EpisodeRecord,
        timestamps: list[float],
        camera_key: str | None = None,
    ) -> list[Any]:
        """Ignore every argument and return no frames."""
        return list()

    def video_for_episode(
        self,
        record: EpisodeRecord,
        max_frames: int,
        camera_key: str | None = None,
    ) -> list[Any]:
        """Ignore every argument and return no frames."""
        return list()
|
||||||
|
|
||||||
|
|
||||||
|
def null_provider() -> FrameProvider:
    """Build a fresh provider that returns no cameras and no frames."""
    provider: FrameProvider = _NullProvider()
    return provider
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoFrameProvider:
|
||||||
|
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||||
|
|
||||||
|
By default the *first* camera key is used for the ``plan`` module
|
||||||
|
(subtask decomposition) and the ``interjections`` module (interjection
|
||||||
|
scenarios) — those prompts care about *what is happening*, not which
|
||||||
|
angle. The ``vqa`` module instead iterates over every camera in
|
||||||
|
:attr:`camera_keys` so each frame's
|
||||||
|
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||||
|
grounded against.
|
||||||
|
|
||||||
|
``camera_key`` overrides the default-camera choice but does not restrict
|
||||||
|
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||||
|
``video_for_episode`` to read a non-default stream.
|
||||||
|
|
||||||
|
Caches up to ``cache_size`` decoded frames per process to keep
|
||||||
|
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
|
||||||
|
"""
|
||||||
|
|
||||||
|
root: Path
|
||||||
|
camera_key: str | None = None
|
||||||
|
tolerance_s: float = 1e-2
|
||||||
|
cache_size: int = 256
|
||||||
|
# Keyframe decode backend. ``None`` uses the ffmpeg CLI — the
|
||||||
|
# concurrency- and crash-safe default for the pipeline's threaded
|
||||||
|
# decode. Set to ``"torchcodec"`` or ``"pyav"`` to pin an in-process
|
||||||
|
# decoder when the build is known thread-safe.
|
||||||
|
video_backend: str | None = None
|
||||||
|
_meta: Any = field(default=None, init=False, repr=False)
|
||||||
|
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||||
|
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||||
|
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
|
||||||
|
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
|
||||||
|
# one-shot warn flag against concurrent updates from worker threads.
|
||||||
|
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
    """Load dataset metadata and settle the camera list and default camera."""
    from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata  # noqa: PLC0415

    self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
    # Restrict to ``video_keys``: the clip/decode paths look up
    # ``videos/<key>/from_timestamp`` in episode metadata, and that entry
    # exists only for video-stored cameras. Image-stored cameras (also in
    # ``camera_keys``) would KeyError, so neither the list nor the default
    # may include them.
    decodable = list(self._meta.video_keys)
    # Last resort: the metadata surfaced no video key, yet the caller named
    # a camera explicitly (``--vlm.camera_key=...``) — take their word that
    # the key exists on the dataset.
    if self.camera_key and not decodable:
        decodable = [self.camera_key]
    self._camera_keys = decodable
    # Without an explicit choice the first decodable key is the default.
    if self.camera_key is None and decodable:
        self.camera_key = decodable[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_keys(self) -> list[str]:
|
||||||
|
"""All ``observation.images.*`` keys available on this dataset."""
|
||||||
|
return list(self._camera_keys)
|
||||||
|
|
||||||
|
def frames_at(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
timestamps: list[float],
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
target = camera_key if camera_key is not None else self.camera_key
|
||||||
|
if not timestamps or target is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
out: list[Any] = []
|
||||||
|
misses: list[float] = []
|
||||||
|
miss_indices: list[int] = []
|
||||||
|
with self._lock:
|
||||||
|
for i, ts in enumerate(timestamps):
|
||||||
|
key = (record.episode_index, target, round(float(ts), 6))
|
||||||
|
cached = self._cache.get(key)
|
||||||
|
if cached is not None:
|
||||||
|
out.append(cached)
|
||||||
|
else:
|
||||||
|
out.append(None)
|
||||||
|
misses.append(float(ts))
|
||||||
|
miss_indices.append(i)
|
||||||
|
|
||||||
|
if misses:
|
||||||
|
decoded = self._decode(record.episode_index, misses, target)
|
||||||
|
# ``_decode`` returns exactly one frame per requested timestamp,
|
||||||
|
# or an empty list if decoding failed wholesale. A partial list
|
||||||
|
# would mean a frame/timestamp misalignment, so only pair them up
|
||||||
|
# when the counts match (``strict=True`` then guards regressions).
|
||||||
|
if len(decoded) == len(miss_indices):
|
||||||
|
with self._lock:
|
||||||
|
for i, frame in zip(miss_indices, decoded, strict=True):
|
||||||
|
out[i] = frame
|
||||||
|
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||||
|
if len(self._cache) >= self.cache_size:
|
||||||
|
self._cache.pop(next(iter(self._cache)))
|
||||||
|
self._cache[key] = frame
|
||||||
|
# filter out any None left over from decode failures
|
||||||
|
return [frame for frame in out if frame is not None]
|
||||||
|
|
||||||
|
def video_for_episode(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
max_frames: int,
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
|
||||||
|
|
||||||
|
The whole episode duration is covered; the model picks subtask
|
||||||
|
boundaries from the temporal pooling it does internally. Frames are
|
||||||
|
``torch.Tensor`` (see :meth:`frames_at`).
|
||||||
|
"""
|
||||||
|
target = camera_key if camera_key is not None else self.camera_key
|
||||||
|
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||||
|
return []
|
||||||
|
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||||
|
if n_frames == len(record.frame_timestamps):
|
||||||
|
timestamps = list(record.frame_timestamps)
|
||||||
|
else:
|
||||||
|
t0 = record.frame_timestamps[0]
|
||||||
|
t_last = record.frame_timestamps[-1]
|
||||||
|
if t_last <= t0:
|
||||||
|
timestamps = [float(t0)] * n_frames
|
||||||
|
else:
|
||||||
|
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||||
|
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||||
|
return self.frames_at(record, timestamps, camera_key=target)
|
||||||
|
|
||||||
|
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
    """Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.

    Returns ``None`` if the provider has no camera key or if extraction
    fails (ffmpeg missing, non-zero exit, timeout, empty output). Skips
    re-extraction when a non-empty cached clip already exists. The clip is
    re-encoded to H.264 (libx264) rather than stream-copied, so the result
    does not inherit the source codec (often AV1 in modern LeRobot
    datasets), which vllm's libav build cannot decode.

    The encode is written to a uniquely named temporary file next to the
    final path and moved into place only after ffmpeg succeeds. This keeps
    a timed-out or crashed encode from leaving a truncated, non-empty mp4
    that a later call would accept as a valid cache hit, and it keeps
    concurrent extractions of the same episode from interleaving writes.

    Args:
        record: Episode whose ``episode_index`` selects the clip.
        cache_dir: Directory for cached clips; created if missing.

    Returns:
        Path to the cached clip, or ``None`` when no clip is available.
    """
    import os  # noqa: PLC0415
    import subprocess  # noqa: PLC0415
    import uuid  # noqa: PLC0415

    if self.camera_key is None:
        return None
    cache_dir.mkdir(parents=True, exist_ok=True)
    out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
    if out_path.exists() and out_path.stat().st_size > 0:
        return out_path
    ep = self._meta.episodes[record.episode_index]
    from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
    to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
    src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
    # Keep the ``.mp4`` suffix so ffmpeg still infers the output container.
    tmp_path = out_path.with_name(f"{out_path.stem}.{uuid.uuid4().hex}.tmp.mp4")
    cmd = [
        "ffmpeg",
        "-y",
        "-loglevel",
        "error",
        "-ss",
        f"{from_timestamp:.3f}",
        "-to",
        f"{to_timestamp:.3f}",
        "-i",
        str(src),
        "-c:v",
        "libx264",
        "-preset",
        "ultrafast",
        "-crf",
        "23",
        "-pix_fmt",
        "yuv420p",
        "-an",
        str(tmp_path),
    ]
    try:
        subprocess.run(cmd, check=True, timeout=300)
        if tmp_path.stat().st_size <= 0:
            return None
        # Atomic on the same filesystem: readers see the old state or the
        # complete clip, never a partial file.
        os.replace(tmp_path, out_path)
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
        return None
    finally:
        # No-op after a successful replace; removes partial output otherwise.
        tmp_path.unlink(missing_ok=True)
    return out_path
|
||||||
|
|
||||||
|
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
    """Decode ``timestamps`` of one episode/camera into per-frame results.

    Each requested timestamp is offset by the episode's
    ``videos/{camera_key}/from_timestamp`` before being handed to a
    backend. The ffmpeg CLI helper is used unless ``self.video_backend``
    is set: ``"pyav"`` / ``"av"`` select ``_decode_frames_av`` and any
    other value is forwarded to ``decode_video_frames``.

    Returns the backend's frames, or ``[]`` if the backend raised; callers
    treat ``[]`` as "no frames available". Only the first failure per
    provider instance is logged.
    """
    episode_meta = self._meta.episodes[episode_index]
    clip_start = episode_meta[f"videos/{camera_key}/from_timestamp"]
    shifted = [clip_start + ts for ts in timestamps]
    video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)

    # The ffmpeg CLI is the default because decoding runs under a thread
    # pool: the in-process decoders can crash the whole job there, while
    # each ffmpeg call is an isolated child process whose failure surfaces
    # as a catchable error. ``video_backend`` pins ``torchcodec`` / ``pyav``
    # for callers that know their build is safe.
    chain = [self.video_backend] if self.video_backend else ["ffmpeg"]

    def _run(backend: str) -> list[Any]:
        if backend == "ffmpeg":
            return _decode_frames_ffmpeg(video_path, shifted)
        if backend in ("pyav", "av"):
            return _decode_frames_av(video_path, shifted)
        # Iterating the stacked result yields one entry per timestamp.
        stacked = decode_video_frames(
            video_path, shifted, self.tolerance_s, backend=backend, return_uint8=True
        )
        return list(stacked)

    last_error: Exception | None = None
    for backend in chain:
        try:
            return _run(backend)
        except Exception as e:  # noqa: PERF203
            last_error = e

    # Every backend raised. Warn once so a run where every prompt is skipped
    # (because no frames came back) is diagnosable from the job log; later
    # failures stay quiet.
    with self._lock:
        first_failure = not getattr(self, "_warned_decode_fail", False)
        if first_failure:
            self._warned_decode_fail = True
    if first_failure:
        logger.warning(
            "VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backends=%s: %s",
            episode_index,
            camera_key,
            video_path,
            chain,
            last_error,
            exc_info=last_error,
        )
    return []
|
||||||
|
|
||||||
|
|
||||||
|
def make_frame_provider(
    root: Path, camera_key: str | None = None, video_backend: str | None = None
) -> FrameProvider:
    """Build a :class:`VideoFrameProvider` if videos are present, else null.

    Construction is best-effort: any exception raised while building the
    provider falls back to the null provider. The swallowed exception is
    logged, so a misconfigured dataset root or broken metadata does not
    silently disable every vision-grounded module.

    Args:
        root: Dataset root directory passed to the provider.
        camera_key: Optional camera to pin on the provider.
        video_backend: Optional decode backend name passed to the provider.

    Returns:
        The video provider when it was built and resolved a camera key,
        otherwise the null provider.
    """
    try:
        provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend)
    except Exception:
        # Deliberate best-effort boundary: keep the fallback, but record why.
        logger.warning(
            "Could not build a VideoFrameProvider for root=%s; falling back to the null provider.",
            root,
            exc_info=True,
        )
        return null_provider()
    if provider.camera_key is None:
        return null_provider()
    return provider
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_frames_ffmpeg(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||||
|
"""Decode the frames nearest to ``timestamps`` via the ffmpeg CLI.
|
||||||
|
|
||||||
|
Runs one ``ffmpeg`` process per timestamp, seeking with ``-ss`` and
|
||||||
|
piping a single PNG to stdout. Unlike the in-process decoders this
|
||||||
|
survives a hostile container: a full ffmpeg build decodes AV1 (the codec
|
||||||
|
modern LeRobot datasets use) where torchcodec raises and PyAV can
|
||||||
|
SIGSEGV, and a crash stays isolated to the child process — a non-zero
|
||||||
|
exit is a catchable error, not a segfault of the whole job. Returns one
|
||||||
|
``(C, H, W)`` uint8 tensor per timestamp.
|
||||||
|
"""
|
||||||
|
import io # noqa: PLC0415
|
||||||
|
import subprocess # noqa: PLC0415
|
||||||
|
|
||||||
|
import numpy as np # noqa: PLC0415
|
||||||
|
|
||||||
|
frames: list[Any] = []
|
||||||
|
for ts in timestamps:
|
||||||
|
proc = subprocess.run(
|
||||||
|
[
|
||||||
|
"ffmpeg",
|
||||||
|
"-nostdin",
|
||||||
|
"-loglevel",
|
||||||
|
"error",
|
||||||
|
"-ss",
|
||||||
|
f"{max(ts, 0.0):.3f}",
|
||||||
|
"-i",
|
||||||
|
str(video_path),
|
||||||
|
"-frames:v",
|
||||||
|
"1",
|
||||||
|
"-f",
|
||||||
|
"image2pipe",
|
||||||
|
"-vcodec",
|
||||||
|
"png",
|
||||||
|
"pipe:1",
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
check=True,
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
if not proc.stdout:
|
||||||
|
raise RuntimeError(f"ffmpeg returned no frame for t={ts:.3f}s of {video_path}")
|
||||||
|
img = PIL.Image.open(io.BytesIO(proc.stdout)).convert("RGB")
|
||||||
|
frames.append(torch.from_numpy(np.asarray(img).copy()).permute(2, 0, 1).contiguous())
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_frames_av(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||||
|
"""Decode the frames nearest to ``timestamps`` using PyAV directly.
|
||||||
|
|
||||||
|
lerobot's ``decode_video_frames(backend="pyav")`` routes through
|
||||||
|
``torchvision.io.VideoReader``, removed in torchvision 0.23+. This helper
|
||||||
|
talks to the ``av`` package directly. Note PyAV can SIGSEGV on AV1
|
||||||
|
streams in some builds — prefer ``_decode_frames_ffmpeg`` as the default
|
||||||
|
fallback; this stays available behind ``video_backend="pyav"``. Returns
|
||||||
|
one ``(C, H, W)`` uint8 tensor per timestamp.
|
||||||
|
"""
|
||||||
|
import av # noqa: PLC0415
|
||||||
|
|
||||||
|
first_ts = min(timestamps)
|
||||||
|
last_ts = max(timestamps)
|
||||||
|
loaded_frames: list[torch.Tensor] = []
|
||||||
|
loaded_ts: list[float] = []
|
||||||
|
with av.open(str(video_path)) as container:
|
||||||
|
stream = container.streams.video[0]
|
||||||
|
# Seek to the keyframe at or before the first requested timestamp.
|
||||||
|
offset = max(int(first_ts / stream.time_base), 0) if stream.time_base else 0
|
||||||
|
container.seek(offset, stream=stream, backward=True, any_frame=False)
|
||||||
|
for idx, frame in enumerate(container.decode(stream)):
|
||||||
|
ts = frame.time
|
||||||
|
if ts is None:
|
||||||
|
ts = float(frame.pts * stream.time_base) if frame.pts is not None else float(idx)
|
||||||
|
loaded_ts.append(ts)
|
||||||
|
loaded_frames.append(
|
||||||
|
torch.from_numpy(frame.to_ndarray(format="rgb24")).permute(2, 0, 1).contiguous()
|
||||||
|
)
|
||||||
|
if ts >= last_ts:
|
||||||
|
break
|
||||||
|
if not loaded_frames:
|
||||||
|
raise RuntimeError(f"PyAV decoded no frames from {video_path}")
|
||||||
|
ts_tensor = torch.tensor(loaded_ts)
|
||||||
|
return [loaded_frames[int(torch.argmin((ts_tensor - q).abs()))] for q in timestamps]
|
||||||
|
|
||||||
|
|
||||||
|
def _frame_to_pil(frame: Any) -> Any:
|
||||||
|
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
|
||||||
|
|
||||||
|
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
|
||||||
|
straight from :func:`decode_video_frames`); PIL is only created here, at
|
||||||
|
the VLM-message boundary, because the chat backends expect PIL images /
|
||||||
|
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
|
||||||
|
"""
|
||||||
|
if not isinstance(frame, torch.Tensor):
|
||||||
|
return frame
|
||||||
|
array = frame.detach().cpu()
|
||||||
|
if array.ndim == 3 and array.shape[0] in (1, 3):
|
||||||
|
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
|
||||||
|
if array.shape[-1] == 1:
|
||||||
|
array = array.squeeze(-1)
|
||||||
|
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
|
||||||
|
|
||||||
|
|
||||||
|
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
    """Build one ``{"type": "image", ...}`` content block per decoded frame.

    Each frame goes through ``_frame_to_pil`` first; order is preserved.
    """
    blocks: list[dict[str, Any]] = []
    for frame in frames:
        blocks.append({"type": "image", "image": _frame_to_pil(frame)})
    return blocks
|
||||||
|
|
||||||
|
|
||||||
|
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
    """Bundle decoded frames into a single ``{"type": "video", ...}`` block.

    The result is a list holding that one block, or ``[]`` when there are
    no frames, so callers can splat it into a content array without a
    separate emptiness check. Each frame goes through ``_frame_to_pil``.
    """
    if frames:
        return [{"type": "video", "video": [_frame_to_pil(item) for item in frames]}]
    return []
|
||||||
|
|
||||||
|
|
||||||
|
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
    """Wrap a video file URL as a single ``video_url`` content block.

    Intended for the ``openai`` backend (transformers serve / vllm serve /
    ktransformers serve), where the server does the frame sampling.

    Args:
        url: Location of the video; ``None`` or an empty string yields ``[]``.
        fps: Value stored under the block's ``"fps"`` key.

    Returns:
        A one-element list with the block, or ``[]`` so callers can splat.
    """
    if url:
        return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
    return []
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .general_vqa import GeneralVqaModule
|
||||||
|
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||||
|
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"GeneralVqaModule",
|
||||||
|
"InterjectionsAndSpeechModule",
|
||||||
|
"PlanSubtasksMemoryModule",
|
||||||
|
]
|
||||||
@@ -0,0 +1,228 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""``vqa`` module: general VQA at a timed cadence.
|
||||||
|
|
||||||
|
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
|
||||||
|
consecutive frames, and every anchored frame gets its own VQA pair. Each
|
||||||
|
pair is grounded on that single anchor frame — there is no per-pair frame
|
||||||
|
window. For datasets with multiple cameras, every anchored frame produces
|
||||||
|
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||||
|
generated against that camera's frame and stamped with the matching
|
||||||
|
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||||
|
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||||
|
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||||
|
|
||||||
|
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||||
|
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||||
|
|
||||||
|
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
|
||||||
|
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||||
|
whose schema depends on the question type. Malformed JSON triggers one
|
||||||
|
retry inside :meth:`VlmClient.generate_json`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..config import VqaConfig
|
||||||
|
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||||
|
from ..prompts import load as load_prompt
|
||||||
|
from ..reader import EpisodeRecord
|
||||||
|
from ..staging import EpisodeStaging
|
||||||
|
from ..validator import classify_vqa_answer
|
||||||
|
from ..vlm_client import VlmClient
|
||||||
|
|
||||||
|
|
||||||
|
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
|
||||||
|
"""Return the relative frame indices to anchor VQA emissions to.
|
||||||
|
|
||||||
|
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
|
||||||
|
consecutive frames starting at the tick. Ticks fall on the nearest
|
||||||
|
available source frame timestamp.
|
||||||
|
"""
|
||||||
|
if hz <= 0 or k <= 0 or not frame_timestamps:
|
||||||
|
return []
|
||||||
|
t0 = frame_timestamps[0]
|
||||||
|
t_last = frame_timestamps[-1]
|
||||||
|
period = 1.0 / hz
|
||||||
|
indices: list[int] = []
|
||||||
|
t = t0
|
||||||
|
while t <= t_last + 1e-9:
|
||||||
|
# find the index of the nearest frame to t
|
||||||
|
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
|
||||||
|
for offset in range(k):
|
||||||
|
j = nearest_i + offset
|
||||||
|
if j >= len(frame_timestamps):
|
||||||
|
break
|
||||||
|
if not indices or indices[-1] != j:
|
||||||
|
indices.append(j)
|
||||||
|
t += period
|
||||||
|
# dedupe while preserving order
|
||||||
|
seen: set[int] = set()
|
||||||
|
deduped: list[int] = []
|
||||||
|
for i in indices:
|
||||||
|
if i in seen:
|
||||||
|
continue
|
||||||
|
seen.add(i)
|
||||||
|
deduped.append(i)
|
||||||
|
return deduped
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GeneralVqaModule:
    """Stage grounded VQA question/answer rows for an episode.

    Anchor frames come from ``_emission_anchor_indices``. For every anchor
    and every camera exposed by the frame provider, one prompt is built;
    all prompts are sent in a single ``generate_json`` call, and each
    usable reply becomes a ``user`` row plus an ``assistant`` row written
    to the ``"vqa"`` staging slot.
    """

    vlm: VlmClient
    config: VqaConfig
    seed: int = 1729
    frame_provider: FrameProvider = field(default_factory=null_provider)

    @property
    def enabled(self) -> bool:
        """Whether the module is switched on in its config."""
        return self.config.enabled

    def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
        """Generate and stage the VQA rows for one episode.

        An empty row list is staged when the episode has no timestamps, no
        camera is available, or no prompt ended up with an image.
        """
        timestamps = record.frame_timestamps
        if not timestamps:
            staging.write("vqa", [])
            return

        # Seeded per episode so the question-type choice is reproducible.
        rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
        anchors = _emission_anchor_indices(timestamps, self.config.vqa_emission_hz, self.config.K)
        cameras = self._target_cameras()
        if not cameras:
            # Without a camera the rows could not be tagged, so emit nothing.
            # Warn once so the module is never silently a no-op.
            if not getattr(self, "_warned_no_camera", False):
                logging.getLogger(__name__).warning(
                    "vqa module found no cameras on the frame provider — "
                    "every episode will emit zero VQA rows. Check that the "
                    "dataset declares observation.images.* features in "
                    "meta/info.json; passing --vlm.camera_key=<key> at the "
                    "CLI now also seeds the cameras list as a fallback."
                )
                self._warned_no_camera = True
            staging.write("vqa", [])
            return

        # Collect every (frame, camera) prompt first so the client receives
        # them as one batch.
        pending: list[tuple[float, str, list[dict[str, Any]]]] = []
        for frame_idx in anchors:
            ts = float(timestamps[frame_idx])
            # One question type per anchor frame, shared by all cameras.
            qtype = rng.choice(self.config.question_types)
            for camera in cameras:
                messages = self._build_messages(record, qtype, ts, camera)
                # A prompt without an image cannot be grounded; drop it.
                if _has_image_block(messages):
                    pending.append((ts, camera, messages))

        if not pending:
            staging.write("vqa", [])
            return

        results = self.vlm.generate_json([messages for _, _, messages in pending])

        rows: list[dict[str, Any]] = []
        for (ts, camera, _messages), result in zip(pending, results, strict=True):
            qa = self._postprocess(result)
            if qa is None:
                continue
            question, answer = qa
            rows.extend(self._qa_rows(ts, camera, question, answer))
        staging.write("vqa", rows)

    @staticmethod
    def _qa_rows(ts: float, camera: str, question: str, answer: dict[str, Any]) -> list[dict[str, Any]]:
        """Build the ``user`` / ``assistant`` row pair for one accepted answer."""
        user_row = {
            "role": "user",
            "content": question,
            "style": "vqa",
            "timestamp": ts,
            "camera": camera,
            "tool_calls": None,
        }
        assistant_row = {
            "role": "assistant",
            "content": json.dumps(answer, sort_keys=True),
            "style": "vqa",
            "timestamp": ts,
            "camera": camera,
            "tool_calls": None,
        }
        return [user_row, assistant_row]

    def _target_cameras(self) -> list[str]:
        """List the cameras to query per anchored frame.

        Reads ``camera_keys`` from the frame provider; a provider without
        that attribute (or with an empty/``None`` value) yields ``[]``,
        which makes ``run_episode`` stage no rows.
        """
        keys = getattr(self.frame_provider, "camera_keys", [])
        return list(keys or [])

    def _build_messages(
        self,
        record: EpisodeRecord,
        question_type: str,
        frame_timestamp: float,
        camera_key: str,
    ) -> list[dict[str, Any]]:
        """Assemble the single-turn user message for one (frame, camera) pair.

        The message content is the image blocks for the requested frame
        followed by the formatted ``module_3_vqa`` prompt text.
        """
        prompt = load_prompt("module_3_vqa").format(
            episode_task=record.episode_task,
            question_type=question_type,
        )
        frames = self.frame_provider.frames_at(record, [frame_timestamp], camera_key=camera_key)
        content = [*to_image_blocks(frames), {"type": "text", "text": prompt}]
        return [{"role": "user", "content": content}]

    def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
        """Extract ``(question, answer)`` from a reply, or ``None`` if unusable.

        A reply is usable when it is a dict with a non-blank string
        ``question`` and a dict ``answer`` that ``classify_vqa_answer``
        recognises. This is an early sanity filter only.
        """
        if not isinstance(result, dict):
            return None
        question, answer = result.get("question"), result.get("answer")
        question_ok = isinstance(question, str) and bool(question.strip())
        if not question_ok or not isinstance(answer, dict):
            return None
        if classify_vqa_answer(answer) is None:
            return None
        return question.strip(), answer
|
||||||
|
|
||||||
|
|
||||||
|
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||||
|
"""Return True if any user content block is a populated image block."""
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.get("content")
|
||||||
|
if not isinstance(content, list):
|
||||||
|
continue
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "image":
|
||||||
|
return True
|
||||||
|
return False
|
||||||
@@ -0,0 +1,210 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
|
||||||
|
|
||||||
|
Two sub-passes:
|
||||||
|
|
||||||
|
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
|
||||||
|
canonical task). No interjection row — the canonical task is already the
|
||||||
|
user utterance from ``meta/tasks.parquet``.
|
||||||
|
|
||||||
|
2. For mid-episode interruptions, emit a co-timestamped pair:
|
||||||
|
{role:user, style:interjection, content:<text>}
|
||||||
|
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
||||||
|
Both rows go in ``language_events`` at the same timestamp.
|
||||||
|
|
||||||
|
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
|
||||||
|
interjection timestamps to refresh the ``plan`` row at the same instant.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import random
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..config import InterjectionsConfig
|
||||||
|
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||||
|
from ..prompts import load as load_prompt
|
||||||
|
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||||
|
from ..staging import EpisodeStaging
|
||||||
|
from ..vlm_client import VlmClient
|
||||||
|
from ..writer import speech_atom
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class InterjectionsAndSpeechModule:
    """Stage task-start speech and mid-episode interjection/speech pairs.

    ``run_episode`` asks the VLM for an opening speech line, then for one
    interjection plus spoken reply at selected subtask boundaries, and
    writes all rows to the ``"interjections"`` staging slot.
    """

    vlm: VlmClient
    config: InterjectionsConfig
    seed: int = 1729
    frame_provider: FrameProvider = field(default_factory=null_provider)

    @property
    def enabled(self) -> bool:
        """Whether the module is switched on in its config."""
        return self.config.enabled

    def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
        """Generate and stage the speech and interjection rows for one episode."""
        stamps = record.frame_timestamps
        rows: list[dict[str, Any]] = []
        episode_end_t: float | None = None
        if stamps:
            start_t = float(stamps[0])
            greeting = self._initial_speech(record)
            if greeting:
                rows.append(speech_atom(start_t, greeting))
            episode_end_t = float(stamps[-1])
        # Subtask spans are rebuilt from what is already staged under
        # ``"plan"`` so each interjection prompt can name the subtasks on
        # either side of its boundary.
        spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
        rows.extend(self._mid_episode_interjections(record, spans))
        staging.write("interjections", rows)

    @staticmethod
    def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
        """Return the text of the last span starting at or before ``t``.

        Scanning stops at the first span that starts after ``t``; ``None``
        is returned when no span qualifies.
        """
        active: str | None = None
        for span in spans:
            if float(span["start"]) > t:
                break
            active = span.get("text")
        return active

    def _initial_speech(self, record: EpisodeRecord) -> str | None:
        """Ask the VLM for the opening speech line; ``None`` if unusable.

        The reply must be a dict whose ``"text"`` entry is a string that is
        non-empty after stripping.
        """
        prompt = load_prompt("module_2_initial_speech").format(
            episode_task=record.episode_task,
        )
        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
        reply = self.vlm.generate_json([messages])[0]
        if not isinstance(reply, dict):
            return None
        text = reply.get("text")
        if not isinstance(text, str):
            return None
        return text.strip() or None

    def _mid_episode_interjections(
        self,
        record: EpisodeRecord,
        subtask_spans: Sequence[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Generate interjections that agree with the recorded trajectory.

        The demonstration is fixed, so a counterfactual request ("actually
        skip the wipe") would contradict what the video then shows. Every
        interjection is therefore anchored at a subtask boundary and
        phrased as a request for the *upcoming* subtask, so the text, the
        plan refresh and the action stream stay consistent.

        Returns an interjection row followed by a speech atom for each
        boundary whose VLM reply was usable, in chronological order.
        """
        limit = self.config.max_interjections_per_episode
        if limit <= 0:
            return []
        if len(subtask_spans) < 2:
            # At least one transition (subtask 0 -> subtask 1) is required.
            return []
        # Deterministic per-episode RNG so reruns pick the same boundaries.
        rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")

        # Candidate boundaries are the starts of every subtask but the
        # first, which sits at the episode start where the opening speech is.
        boundaries: list[tuple[float, str, str]] = []
        for i in range(1, len(subtask_spans)):
            boundary_t = float(subtask_spans[i]["start"])
            if boundary_t < self.config.interjection_min_t:
                continue
            before = (subtask_spans[i - 1].get("text") or "").strip()
            after = (subtask_spans[i].get("text") or "").strip()
            if after:
                boundaries.append((boundary_t, before, after))
        if not boundaries:
            return []

        count = min(limit, len(boundaries))
        chosen = sorted(rng.sample(boundaries, count), key=lambda b: b[0])

        out: list[dict[str, Any]] = []
        for t, prev_subtask, next_subtask in chosen:
            t_snap = snap_to_frame(t, record.frame_timestamps)
            # The frame window straddles the boundary, so the VLM sees the
            # end of the previous subtask and the start of the next one.
            window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
            prompt = load_prompt("module_2_interjection").format(
                episode_task=record.episode_task,
                prev_subtask=prev_subtask or "(starting from initial state)",
                next_subtask=next_subtask,
                timestamp=t_snap,
                window_seconds=self.config.interjection_window_seconds,
            )
            frames = self.frame_provider.frames_at(record, window_ts)
            content = [*to_image_blocks(frames), {"type": "text", "text": prompt}]
            reply = self.vlm.generate_json([[{"role": "user", "content": content}]])[0]
            if not isinstance(reply, dict):
                continue
            interjection_text = reply.get("interjection")
            speech_text = reply.get("speech")
            usable = all(
                isinstance(value, str) and value.strip()
                for value in (interjection_text, speech_text)
            )
            if not usable:
                continue
            out.append(
                {
                    "role": "user",
                    "content": interjection_text.strip(),
                    "style": "interjection",
                    "timestamp": t_snap,
                    "tool_calls": None,
                }
            )
            out.append(speech_atom(t_snap, speech_text.strip()))
        return out

    def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
        """Pick a few frame timestamps centred on ``t_anchor``.

        ``interjection_window_frames`` evenly spaced targets cover
        ``interjection_window_seconds`` around the anchor, half before and
        half after. Each target is clamped to ``[0, last frame timestamp]``,
        snapped with ``snap_to_frame`` and deduplicated in order. With no
        frame timestamps, or a single-frame window, only the anchor is
        returned.
        """
        if not frame_timestamps:
            return [t_anchor]
        n = max(1, int(self.config.interjection_window_frames))
        if n == 1:
            return [t_anchor]
        window = float(self.config.interjection_window_seconds)
        step = window / max(1, n - 1)
        # Centre the window on the anchor.
        start_offset = -window / 2.0
        last_ts = float(frame_timestamps[-1])
        targets = [t_anchor + start_offset + step * i for i in range(n)]
        # Dict keys keep first-seen order while dropping repeated snaps.
        unique = dict.fromkeys(
            snap_to_frame(min(last_ts, max(0.0, target)), frame_timestamps) for target in targets
        )
        return list(unique) or [t_anchor]
|
||||||
@@ -0,0 +1,710 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..config import PlanConfig
|
||||||
|
from ..frames import (
|
||||||
|
FrameProvider,
|
||||||
|
VideoFrameProvider,
|
||||||
|
null_provider,
|
||||||
|
to_image_blocks,
|
||||||
|
to_video_block,
|
||||||
|
to_video_url_block,
|
||||||
|
)
|
||||||
|
from ..prompts import load as load_prompt
|
||||||
|
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||||
|
from ..staging import EpisodeStaging
|
||||||
|
from ..vlm_client import VlmClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PlanSubtasksMemoryModule:
|
||||||
|
"""Generate subtask spans, plan, and memory rows.
|
||||||
|
|
||||||
|
All output is persistent (lives in ``language_persistent``):
|
||||||
|
|
||||||
|
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
|
||||||
|
(snapped to an exact frame).
|
||||||
|
- ``plan`` rows: emitted at every subtask boundary (including the start
of the first subtask); extra rows are appended at every interjection
timestamp via :meth:`run_plan_updates` (called by the executor after
the ``interjections`` module completes).
|
||||||
|
- ``memory`` rows: emitted at each subtask boundary (= subtask start
|
||||||
|
timestamp from the second subtask onward).
|
||||||
|
"""
|
||||||
|
|
||||||
|
vlm: VlmClient
|
||||||
|
config: PlanConfig
|
||||||
|
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||||
|
|
||||||
|
@property
def enabled(self) -> bool:
    """Report whether this module is switched on (``self.config.enabled``)."""
    cfg = self.config
    return cfg.enabled
|
||||||
|
|
||||||
|
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
    """Build every ``plan``-module row for one episode and stage it.

    Rows are appended in this order and written once via
    ``staging.write("plan", rows)``:

    1. ``task_aug`` rows (role ``user``) at the first frame timestamp:
       the effective task followed by its de-duplicated variants, taken
       from the axis taxonomy when ``task_aug_axes.enabled`` is set,
       otherwise from ``n_task_rephrasings`` free-form rephrasings.
    2. One ``subtask`` row per span at the span's snapped start,
       optionally followed by an ``action_record`` row holding the JSON
       dump of the typed record extracted for that span.
    3. One ``plan`` row per subtask boundary (first span included).
    4. One ``memory`` row per boundary from the second span onward, each
       conditioned on the previously emitted memory text.
    """

    def make_row(role: str, content: str, style: str, timestamp: float) -> dict[str, Any]:
        # Single place that fixes the row schema and key order.
        return {
            "role": role,
            "content": content,
            "style": style,
            "timestamp": timestamp,
            "tool_calls": None,
        }

    def snap(t: float) -> float:
        return snap_to_frame(t, record.frame_timestamps)

    rows: list[dict[str, Any]] = []

    # The task string that drives every other prompt of this module; it
    # may be the recorded task or one derived from the video.
    effective_task = self._resolve_effective_task(record)

    # --- task_aug rows -------------------------------------------------
    t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
    axes_cfg = self.config.task_aug_axes
    phrasings: list[str] = []
    if axes_cfg.enabled and effective_task:
        phrasings = [effective_task, *self._generate_task_aug_by_axes(effective_task, axes_cfg)]
    elif self.config.n_task_rephrasings > 0 and effective_task:
        # The effective task always leads so the source-of-truth wording
        # is part of the set, not only the synthetic alternatives.
        phrasings = [
            effective_task,
            *self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings),
        ]
    emitted: set[str] = set()
    for phrasing in phrasings:
        text = phrasing.strip()
        if text and text not in emitted:
            emitted.add(text)
            rows.append(make_row("user", text, "task_aug", t0))

    subtask_spans = self._generate_subtasks(record, task=effective_task)

    # --- structured per-subtask action records -------------------------
    # When enabled, each span gets a typed record from the VLM; the
    # deterministic render of that record may replace the free-form text.
    records_cfg = self.config.action_records
    action_records: list[dict[str, Any] | None] = [None] * len(subtask_spans)
    if records_cfg.enabled and subtask_spans:
        for idx, span in enumerate(subtask_spans):
            extracted = self._extract_action_record(record, span, effective_task)
            if extracted is None:
                continue
            action_records[idx] = extracted
            if not records_cfg.replace_subtask_text:
                continue
            rendered = self._render_action_record_to_subtask_text(extracted)
            if rendered:
                span["text"] = rendered

    # --- subtask rows (+ optional action_record rows) ------------------
    for span, extracted in zip(subtask_spans, action_records, strict=True):
        rows.append(make_row("assistant", span["text"], "subtask", snap(span["start"])))
        if records_cfg.enabled and records_cfg.emit_record_row and extracted is not None:
            rows.append(
                make_row(
                    "assistant",
                    json.dumps(extracted, sort_keys=True),
                    "action_record",
                    snap(span["start"]),
                )
            )

    # --- plan rows -----------------------------------------------------
    # Re-emitting the plan at each boundary makes it shrink as work
    # progresses: every emission lists only the spans starting at or
    # after that boundary.
    for span in subtask_spans:
        boundary_t = snap(span["start"])
        plan_text = self._generate_plan(record, subtask_spans, refresh_t=boundary_t, task=effective_task)
        if plan_text is None:
            continue
        rows.append(make_row("assistant", plan_text, "plan", float(boundary_t)))

    # --- memory rows ---------------------------------------------------
    # One per boundary except the very first start; each call sees the
    # memory text produced at the previous boundary.
    prior_memory = ""
    for idx, (finished, upcoming) in enumerate(zip(subtask_spans, subtask_spans[1:]), start=1):
        still_todo = [s["text"] for s in subtask_spans[idx:]]
        mem_text = self._generate_memory(
            record, prior_memory, finished["text"], still_todo, task=effective_task
        )
        if not mem_text:
            continue
        rows.append(make_row("assistant", mem_text, "memory", snap(upcoming["start"])))
        prior_memory = mem_text

    staging.write("plan", rows)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Task derivation + rephrasings
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"debug",
|
||||||
|
"test",
|
||||||
|
"tbd",
|
||||||
|
"todo",
|
||||||
|
"n/a",
|
||||||
|
"na",
|
||||||
|
"untitled",
|
||||||
|
"unnamed",
|
||||||
|
"default",
|
||||||
|
"placeholder",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
|
||||||
|
"""Decide which task string drives the ``plan`` module for this episode.
|
||||||
|
|
||||||
|
Returns the user-supplied ``record.episode_task`` unless
|
||||||
|
``derive_task_from_video`` says otherwise (see config docstring).
|
||||||
|
Falls back gracefully to the canonical task if video derivation
|
||||||
|
fails.
|
||||||
|
"""
|
||||||
|
canonical = (record.episode_task or "").strip()
|
||||||
|
mode = (self.config.derive_task_from_video or "off").strip().lower()
|
||||||
|
if mode == "always":
|
||||||
|
derived = self._derive_task_from_video(record)
|
||||||
|
return derived or canonical
|
||||||
|
if mode == "if_short" and self._task_seems_bad(canonical):
|
||||||
|
derived = self._derive_task_from_video(record)
|
||||||
|
if derived:
|
||||||
|
return derived
|
||||||
|
return canonical
|
||||||
|
|
||||||
|
def _task_seems_bad(self, task: str) -> bool:
|
||||||
|
if not task:
|
||||||
|
return True
|
||||||
|
if len(task.split()) < int(self.config.derive_task_min_words):
|
||||||
|
return True
|
||||||
|
return task.lower() in self._PLACEHOLDER_TASKS
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# VLM call helpers (factored out: every ``plan``-module prompt below follows
|
||||||
|
# the same "build messages → single VLM call → pull a named field"
|
||||||
|
# shape, only differing in field name + post-processing).
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
|
||||||
|
"""Run a single VLM call and return ``result[field]`` or ``None``.
|
||||||
|
|
||||||
|
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
|
||||||
|
dance every prompt-call site needs.
|
||||||
|
"""
|
||||||
|
result = self.vlm.generate_json([messages])[0]
|
||||||
|
if isinstance(result, dict):
|
||||||
|
return result.get(field)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _text_message(text: str) -> list[dict[str, Any]]:
|
||||||
|
"""One-shot text-only user message wrapped for ``generate_json``."""
|
||||||
|
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
|
||||||
|
|
||||||
|
def _video_message(self, record: EpisodeRecord, prompt: str) -> list[dict[str, Any]]:
|
||||||
|
"""User message combining the episode video block with ``prompt``."""
|
||||||
|
content = [*self._episode_video_block(record), {"type": "text", "text": prompt}]
|
||||||
|
return [{"role": "user", "content": content}]
|
||||||
|
|
||||||
|
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
    """Ask the VLM to name the task from the video alone, without a task hint.

    Uses the ``module_1_video_task`` prompt and reads the ``task`` field
    of the reply. Returns the stripped text, or ``None`` when the field
    is not a string or is blank.
    """
    prompt = load_prompt("module_1_video_task")
    answer = self._vlm_field(self._video_message(record, prompt), "task")
    if not isinstance(answer, str):
        return None
    cleaned = answer.strip()
    return cleaned or None
|
||||||
|
|
||||||
|
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
|
||||||
|
"""Generate ``n`` text-only paraphrases of ``base_task``."""
|
||||||
|
if n <= 0 or not base_task:
|
||||||
|
return []
|
||||||
|
prompt = load_prompt("module_1_task_rephrasings").format(base_task=base_task, n=n)
|
||||||
|
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
|
||||||
|
if not isinstance(raw, list):
|
||||||
|
return []
|
||||||
|
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
|
||||||
|
return [s for s in out if s][:n]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Phase 1a + 1b: structured per-subtask action records
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _extract_action_record(
    self,
    record: EpisodeRecord,
    span: dict[str, Any],
    episode_task: str,
) -> dict[str, Any] | None:
    """Ask the VLM for a typed action record describing one subtask span.

    Up to ``config.action_records.frames_per_subtask`` timestamps are
    spread uniformly over ``[span.start, span.end]`` (a single midpoint
    for zero-length spans or when one frame is requested), the matching
    frames are fetched from ``self.frame_provider`` and sent together
    with the ``module_1_action_record`` prompt.

    Args:
        record: Episode the span belongs to.
        span: Subtask span; ``start``, ``end`` and ``text`` are read.
        episode_task: Task string inserted into the prompt.

    Returns:
        A dict with keys ``verb``, ``object``, ``arm``, ``grasp_type``,
        ``destination`` and ``mistake``, or ``None`` when no frames are
        available, the reply is not a dict, the verb is missing or
        outside ``verb_vocabulary``, or the object is missing. Optional
        fields are ``None`` when absent, blank, not a string, or (for
        ``grasp_type`` / ``arm``) outside the allowed values.
    """
    cfg = self.config.action_records
    start_t = float(span.get("start", 0.0))
    end_t = float(span.get("end", start_t))
    duration = max(0.0, end_t - start_t)

    # Uniform timestamps within the span; a single center frame when
    # only one is requested or the span has no duration.
    n = max(1, int(cfg.frames_per_subtask))
    if n == 1 or duration <= 0.0:
        timestamps = [0.5 * (start_t + end_t)]
    else:
        step = duration / (n - 1)
        timestamps = [start_t + i * step for i in range(n)]
    frames = self.frame_provider.frames_at(record, timestamps)
    if not frames:
        logger.debug(
            "action_record: no frames at span %.2f-%.2f for ep %s; skipping",
            start_t,
            end_t,
            record.episode_index,
        )
        return None

    prompt = load_prompt("module_1_action_record").format(
        episode_task=episode_task,
        subtask_text=span.get("text", ""),
        start_time=start_t,
        end_time=end_t,
        duration=duration,
        n_frames=len(frames),
        verb_vocabulary=", ".join(cfg.verb_vocabulary),
        grasp_vocabulary=" | ".join(f'"{g}"' for g in cfg.grasp_vocabulary),
    )
    message = [
        {
            "role": "user",
            "content": [*to_image_blocks(frames), {"type": "text", "text": prompt}],
        }
    ]
    result = self.vlm.generate_json([message])[0]
    if not isinstance(result, dict):
        return None

    def text_field(key: str, *, lowercase: bool = False) -> str | None:
        # Decoded JSON can hold lists, numbers or objects in any slot, so
        # anything that is not a non-blank string is treated as missing
        # instead of being passed to ``str`` methods.
        value = result.get(key)
        if not isinstance(value, str):
            return None
        value = value.strip()
        if lowercase:
            value = value.lower()
        return value or None

    # The verb is required and must come from the configured vocabulary.
    verb = text_field("verb", lowercase=True)
    if verb is None or verb not in {v.lower() for v in cfg.verb_vocabulary}:
        return None
    # The object is required too, but is free text.
    obj = text_field("object")
    if obj is None:
        return None

    # Optional closed-set fields: out-of-vocabulary values become None
    # rather than rejecting the whole record.
    grasp = text_field("grasp_type", lowercase=True)
    if grasp is not None and grasp not in {g.lower() for g in cfg.grasp_vocabulary}:
        grasp = None
    arm = text_field("arm", lowercase=True)
    if arm not in {"left", "right", "both"}:
        arm = None

    return {
        "verb": verb,
        "object": obj,
        "arm": arm,
        "grasp_type": grasp,
        "destination": text_field("destination"),
        "mistake": text_field("mistake"),
    }
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _render_action_record_to_subtask_text(record: dict[str, Any]) -> str:
|
||||||
|
"""Deterministic template: ``ActionRecord`` → canonical subtask text.
|
||||||
|
|
||||||
|
Mirrors the authoring guidance in ``module_1_subtasks.txt``:
|
||||||
|
imperative, drop articles / adverbs, use canonical object nouns,
|
||||||
|
append arm / grasp clauses only when present.
|
||||||
|
|
||||||
|
Examples (record → rendered text)::
|
||||||
|
|
||||||
|
{verb=pick, object=blue cube}
|
||||||
|
→ "pick blue cube"
|
||||||
|
{verb=pick, object=blue cube, arm=left, grasp_type=pinch}
|
||||||
|
→ "pick blue cube with left arm using pinch grip"
|
||||||
|
{verb=place, object=blue cube, destination=green box}
|
||||||
|
→ "place blue cube in green box"
|
||||||
|
{verb=move, object=mug, destination=stove}
|
||||||
|
→ "move mug to stove"
|
||||||
|
"""
|
||||||
|
verb = (record.get("verb") or "").strip().lower()
|
||||||
|
obj = (record.get("object") or "").strip()
|
||||||
|
arm = (record.get("arm") or "").strip().lower() if record.get("arm") else ""
|
||||||
|
grasp = (record.get("grasp_type") or "").strip().lower() if record.get("grasp_type") else ""
|
||||||
|
dest = (record.get("destination") or "").strip() if record.get("destination") else ""
|
||||||
|
|
||||||
|
if not verb:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
parts: list[str] = [verb]
|
||||||
|
if obj:
|
||||||
|
parts.append(obj)
|
||||||
|
if dest:
|
||||||
|
# Pick a sensible preposition per verb family.
|
||||||
|
if verb in {"place", "put", "drop", "insert", "pour", "dump"}:
|
||||||
|
parts.append(f"in {dest}")
|
||||||
|
elif verb in {"move", "transport", "reach"}:
|
||||||
|
parts.append(f"to {dest}")
|
||||||
|
else:
|
||||||
|
parts.append(f"at {dest}")
|
||||||
|
if arm == "both":
|
||||||
|
parts.append("with both arms")
|
||||||
|
elif arm in {"left", "right"}:
|
||||||
|
parts.append(f"with {arm} arm")
|
||||||
|
if grasp:
|
||||||
|
parts.append(f"using {grasp} grip")
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Structured 5-axis task augmentation (EgoMimic-style taxonomy)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _generate_task_aug_by_axes(self, base_task: str, axes_cfg: Any) -> list[str]:
|
||||||
|
"""One VLM call → variants along the 5-axis taxonomy.
|
||||||
|
|
||||||
|
Variants from all axes are flattened into a single list (the
|
||||||
|
downstream pipeline doesn't need to know about the per-axis
|
||||||
|
bucketing — every variant becomes a ``task_aug`` row). Order
|
||||||
|
is preserved for reproducibility: synonym_paraphrase first,
|
||||||
|
then omit_arm, then omit_orientation, then omit_grasp_method,
|
||||||
|
then combined_omissions.
|
||||||
|
"""
|
||||||
|
if not base_task:
|
||||||
|
return []
|
||||||
|
prompt = load_prompt("module_1_task_aug_axes").format(
|
||||||
|
base_task=base_task,
|
||||||
|
n_synonym=axes_cfg.synonym_paraphrase,
|
||||||
|
n_omit_arm=axes_cfg.omit_arm,
|
||||||
|
n_omit_orientation=axes_cfg.omit_orientation,
|
||||||
|
n_omit_grasp_method=axes_cfg.omit_grasp_method,
|
||||||
|
n_combined=axes_cfg.combined_omissions,
|
||||||
|
)
|
||||||
|
result = self.vlm.generate_json([self._text_message(prompt)])[0]
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
return []
|
||||||
|
ordered_axes = (
|
||||||
|
"synonym_paraphrase",
|
||||||
|
"omit_arm",
|
||||||
|
"omit_orientation",
|
||||||
|
"omit_grasp_method",
|
||||||
|
"combined_omissions",
|
||||||
|
)
|
||||||
|
flat: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for axis in ordered_axes:
|
||||||
|
entries = result.get(axis)
|
||||||
|
if not isinstance(entries, list):
|
||||||
|
continue
|
||||||
|
for item in entries:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
continue
|
||||||
|
key = item.strip().strip('"').strip("'")
|
||||||
|
if not key or key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
flat.append(key)
|
||||||
|
return flat
|
||||||
|
|
||||||
|
def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]:
|
||||||
|
"""Same video block ``_generate_subtasks`` builds — extracted helper."""
|
||||||
|
if not record.frame_timestamps:
|
||||||
|
return []
|
||||||
|
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
|
||||||
|
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
|
||||||
|
clip = self.frame_provider.episode_clip_path(record, cache_dir)
|
||||||
|
return (
|
||||||
|
to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
|
||||||
|
if clip is not None
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||||
|
target_count = max(1, int(round(episode_duration * self.config.frames_per_second)))
|
||||||
|
target_count = min(target_count, self.config.max_video_frames)
|
||||||
|
video_frames = self.frame_provider.video_for_episode(record, target_count)
|
||||||
|
return to_video_block(video_frames)
|
||||||
|
|
||||||
|
def run_plan_updates(
    self,
    record: EpisodeRecord,
    staging: EpisodeStaging,
    interjection_times: Sequence[float],
    interjection_texts: Sequence[str] | None = None,
) -> None:
    """Add a ``plan`` row at each interjection timestamp and re-stage.

    The rows already staged under ``"plan"`` are read back, subtask spans
    are rebuilt from them with ``reconstruct_subtask_spans``, and every
    interjection time is snapped to a frame. Snapped times that already
    carry a ``plan`` row (or repeat within this call) are skipped. For
    the rest a plan is generated with ``refresh_t`` set to the snapped
    time and appended when it is not ``None``. The combined rows are
    written back with ``staging.write("plan", ...)``.

    Args:
        record: Episode being annotated.
        staging: Staging area holding the rows written by ``run_episode``.
        interjection_times: Raw interjection timestamps.
        interjection_texts: Optional interjection texts, one per time;
            falsy entries are passed on as ``None``. They are forwarded
            to ``_generate_plan`` as ``interjection``.

    Raises:
        ValueError: from ``zip(..., strict=True)`` when
            ``interjection_texts`` is given with a different length than
            ``interjection_times``.
    """
    staged_rows = staging.read("plan")
    # Hand over the last frame timestamp as the episode end so the final
    # span can be closed instead of ending where it starts.
    episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
    spans = reconstruct_subtask_spans(staged_rows, episode_end_t=episode_end_t)

    planned_at: set[float] = set()
    for row in staged_rows:
        if row.get("style") == "plan":
            planned_at.add(float(row["timestamp"]))
    merged_rows = list(staged_rows)

    texts: list[str | None]
    if interjection_texts is None:
        texts = [None] * len(interjection_times)
    else:
        texts = [str(entry) if entry else None for entry in interjection_texts]

    for raw_t, spoken in zip(interjection_times, texts, strict=True):
        snapped_t = snap_to_frame(raw_t, record.frame_timestamps)
        if snapped_t in planned_at:
            continue
        planned_at.add(snapped_t)
        plan_text = self._generate_plan(record, spans, refresh_t=snapped_t, interjection=spoken)
        if plan_text is None:
            continue
        merged_rows.append(
            {
                "role": "assistant",
                "content": plan_text,
                "style": "plan",
                "timestamp": snapped_t,
                "tool_calls": None,
            }
        )
    staging.write("plan", merged_rows)
|
||||||
|
|
||||||
|
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
|
||||||
|
if record.row_count == 0 or not record.frame_timestamps:
|
||||||
|
return []
|
||||||
|
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||||
|
prompt = load_prompt("module_1_subtasks").format(
|
||||||
|
episode_task=(task if task is not None else record.episode_task),
|
||||||
|
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||||
|
max_steps=self.config.plan_max_steps,
|
||||||
|
episode_duration=f"{episode_duration:.3f}",
|
||||||
|
)
|
||||||
|
messages = self._video_message(record, prompt)
|
||||||
|
spans = self._vlm_field(messages, "subtasks")
|
||||||
|
if not spans:
|
||||||
|
return []
|
||||||
|
# clamp to [t0, t_last] and sort
|
||||||
|
t0 = record.frame_timestamps[0]
|
||||||
|
t_last = record.frame_timestamps[-1]
|
||||||
|
cleaned: list[dict[str, Any]] = []
|
||||||
|
for span in spans:
|
||||||
|
try:
|
||||||
|
start = float(span["start"])
|
||||||
|
end = float(span["end"])
|
||||||
|
text = str(span["text"]).strip()
|
||||||
|
except (KeyError, ValueError, TypeError):
|
||||||
|
continue
|
||||||
|
start = max(t0, min(start, t_last))
|
||||||
|
end = max(t0, min(end, t_last))
|
||||||
|
if end < start:
|
||||||
|
start, end = end, start
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
cleaned.append({"text": text, "start": start, "end": end})
|
||||||
|
cleaned.sort(key=lambda s: s["start"])
|
||||||
|
cleaned = self._dedupe_starts_to_distinct_frames(cleaned, record)
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _dedupe_starts_to_distinct_frames(
|
||||||
|
spans: list[dict[str, Any]], record: EpisodeRecord
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Bump same-frame subtask starts onto distinct frames.
|
||||||
|
|
||||||
|
Two consecutive VLM spans whose ``start`` rounds to the same
|
||||||
|
source frame (after :func:`snap_to_frame`) would otherwise emit
|
||||||
|
two ``style=subtask`` rows at the identical persistent
|
||||||
|
timestamp. The training-time renderer's ``active_at(t,
|
||||||
|
style=subtask)`` resolver can't disambiguate that and raises
|
||||||
|
``Ambiguous resolver for style='subtask'``.
|
||||||
|
|
||||||
|
Walk the (sorted-by-start) spans, snap each to its frame, and
|
||||||
|
if the snapped frame is already taken push the span onto the
|
||||||
|
next unused frame so both subtasks survive on distinct
|
||||||
|
timestamps. If the episode ends before a free frame is found,
|
||||||
|
the trailing span is dropped with a warning — better than
|
||||||
|
poisoning the render.
|
||||||
|
"""
|
||||||
|
if not spans:
|
||||||
|
return spans
|
||||||
|
frames = record.frame_timestamps
|
||||||
|
if not frames:
|
||||||
|
return spans
|
||||||
|
used: set[float] = set()
|
||||||
|
out: list[dict[str, Any]] = []
|
||||||
|
for span in spans:
|
||||||
|
ts = snap_to_frame(span["start"], frames)
|
||||||
|
if ts in used:
|
||||||
|
next_ts = next((f for f in frames if f > ts and f not in used), None)
|
||||||
|
if next_ts is None:
|
||||||
|
logger.warning(
|
||||||
|
"episode %d: subtask %r snapped to occupied frame "
|
||||||
|
"%.3f and no free later frame exists — dropping",
|
||||||
|
record.episode_index,
|
||||||
|
span.get("text"),
|
||||||
|
ts,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
ts = next_ts
|
||||||
|
used.add(ts)
|
||||||
|
new_span = {**span, "start": ts}
|
||||||
|
if float(new_span.get("end", ts)) < ts:
|
||||||
|
new_span["end"] = ts
|
||||||
|
out.append(new_span)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _generate_plan(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord, # noqa: ARG002 (kept for signature stability)
|
||||||
|
subtask_spans: Sequence[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
refresh_t: float | None = None,
|
||||||
|
interjection: str | None = None, # noqa: ARG002
|
||||||
|
task: str | None = None, # noqa: ARG002
|
||||||
|
) -> str | None:
|
||||||
|
"""Deterministic plan = numbered list of *still-todo* subtasks.
|
||||||
|
|
||||||
|
Previously this called the VLM with a prompt that asked it to
|
||||||
|
compress the subtasks into a "compact hierarchical plan". That
|
||||||
|
produced longer-than-necessary plans, cost an extra VLM round-trip
|
||||||
|
per episode (plus one per interjection on refresh), and could
|
||||||
|
diverge from the actual subtask sequence the model is going to
|
||||||
|
execute. Replacing it with a plain summarisation keeps the plan
|
||||||
|
tightly aligned with the upcoming subtasks and removes the VLM
|
||||||
|
call entirely.
|
||||||
|
|
||||||
|
Layout — short imperative fragments prefixed by "N. ":
|
||||||
|
|
||||||
|
1. <subtask 1>
|
||||||
|
2. <subtask 2>
|
||||||
|
...
|
||||||
|
|
||||||
|
On a refresh at ``refresh_t`` (called from ``run_plan_updates``
|
||||||
|
on interjection events, and from ``run_episode`` at every subtask
|
||||||
|
boundary), only subtasks whose start is at or after ``refresh_t``
|
||||||
|
are included — the plan shrinks as work progresses, so it always
|
||||||
|
describes what's left.
|
||||||
|
"""
|
||||||
|
if not subtask_spans:
|
||||||
|
return None
|
||||||
|
remaining = [
|
||||||
|
s
|
||||||
|
for s in subtask_spans
|
||||||
|
if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
|
||||||
|
]
|
||||||
|
if not remaining:
|
||||||
|
# Past the last subtask boundary on a late refresh — nothing
|
||||||
|
# left to plan; emit None so the caller skips the row.
|
||||||
|
return None
|
||||||
|
return "\n".join(
|
||||||
|
f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_memory(
    self,
    record: EpisodeRecord,
    prior_memory: str,
    completed: str,
    remaining: Sequence[str],
    *,
    task: str | None = None,
) -> str:
    """Ask the VLM for the memory text to emit at a subtask boundary.

    Args:
        record: Episode; supplies ``episode_task`` when ``task`` is ``None``.
        prior_memory: Memory text from the previous boundary; shown to
            the model as ``(none)`` when empty.
        completed: Text of the subtask that just finished.
        remaining: Texts of the subtasks still to do; joined with
            ``", "`` or shown as ``(none)`` when empty.
        task: Optional task string overriding ``record.episode_task``.

    Returns:
        The stripped ``memory`` field of the reply, or ``""`` when that
        field is not a string.
    """
    task_text = record.episode_task if task is None else task
    todo_text = ", ".join(remaining) if remaining else "(none)"
    prompt = load_prompt("module_1_memory").format(
        episode_task=task_text,
        prior_memory=prior_memory or "(none)",
        completed_subtask=completed,
        remaining_subtasks=todo_text,
    )
    memory = self._vlm_field(self._text_message(prompt), "memory")
    if not isinstance(memory, str):
        return ""
    return memory.strip()
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Prompt templates loaded as plain text.
|
||||||
|
|
||||||
|
One file per use site. Templates use ``str.format(**vars)`` substitution; we
|
||||||
|
intentionally avoid jinja2 here so the templates remain inspectable in
|
||||||
|
plain editors and roundtrip cleanly through ``ruff format``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Directory containing this module; ``load`` resolves ``<name>.txt`` templates relative to it.
_DIR = Path(__file__).parent
|
||||||
|
|
||||||
|
|
||||||
|
def load(name: str) -> str:
    """Return the UTF-8 text of the prompt template ``<name>.txt`` next to this module."""
    return (_DIR / f"{name}.txt").read_text(encoding="utf-8")
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
You are extracting a structured action record from a subtask span of a
|
||||||
|
teleoperated robot demonstration. This is Phase 1a of a two-step
|
||||||
|
process: you extract a typed record; a deterministic template then
|
||||||
|
renders it back to canonical subtask text. Your job is the PERCEPTION
|
||||||
|
step — not the language step.
|
||||||
|
|
||||||
|
The user originally asked: "{episode_task}"
|
||||||
|
The subtask span is: "{subtask_text}"
|
||||||
|
Span time window: [{start_time:.2f}s, {end_time:.2f}s]
|
||||||
|
({duration:.2f}s of robot activity)
|
||||||
|
|
||||||
|
You are shown {n_frames} frames sampled uniformly from the subtask
|
||||||
|
window. Fill in a structured record describing the action that takes
|
||||||
|
place between the first and last frame.
|
||||||
|
|
||||||
|
Hard rules:
|
||||||
|
- Use ONLY information visible in the frames. Do not infer details from
|
||||||
|
outside the span. Do not extrapolate from the original task wording.
|
||||||
|
- Use canonical object names from the original task VERBATIM. Never
|
||||||
|
introduce synonyms: if the task says "cube", the record says "cube",
|
||||||
|
never "block" / "object" / "item".
|
||||||
|
- For non-applicable fields, use ``null`` (not "n/a", not "none", not
|
||||||
|
an empty string).
|
||||||
|
- For ``verb`` and ``grasp_type``, pick EXACTLY one value from the
|
||||||
|
vocabulary below. Never invent a new one.
|
||||||
|
|
||||||
|
Field schema:
|
||||||
|
|
||||||
|
verb (required) — the imperative verb of the action. Vocabulary:
|
||||||
|
{verb_vocabulary}
|
||||||
|
|
||||||
|
object (required) — the manipulated object. Use the canonical noun
|
||||||
|
from the original task above.
|
||||||
|
|
||||||
|
arm — which arm performs the action. One of:
|
||||||
|
"left" | "right" | "both" | null
|
||||||
|
Use ``null`` when the source robot is single-arm or when the arm
|
||||||
|
is genuinely not visible in the frames.
|
||||||
|
|
||||||
|
grasp_type — which grip the gripper uses on contact. One of:
|
||||||
|
{grasp_vocabulary} | null
|
||||||
|
Use ``null`` when there is no contact in this span (e.g. a pure
|
||||||
|
``move`` / ``reach`` subtask) or the grip is genuinely unclear.
|
||||||
|
|
||||||
|
destination — the target location for actions like ``place``,
|
||||||
|
``move``, ``insert``, ``pour``. Use canonical names from the
|
||||||
|
original task. Use ``null`` for in-place actions (``press``,
|
||||||
|
``turn``, ``grasp``, ``release``).
|
||||||
|
|
||||||
|
mistake — a brief one-clause description of any visible failure or
|
||||||
|
recovery during the span (e.g. "dropped the cube and re-grasped",
|
||||||
|
"missed the target on first attempt"). Use ``null`` when the span
|
||||||
|
completes cleanly with no visible recovery.
|
||||||
|
|
||||||
|
Output strictly valid JSON of shape:
|
||||||
|
|
||||||
|
{{
|
||||||
|
"verb": "<one of vocabulary>",
|
||||||
|
"object": "<canonical noun>",
|
||||||
|
"arm": "left" | "right" | "both" | null,
|
||||||
|
"grasp_type": "<one of vocabulary>" | null,
|
||||||
|
"destination": "<canonical noun>" | null,
|
||||||
|
"mistake": "<short description>" | null
|
||||||
|
}}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
You are updating the robot's compressed semantic memory at the boundary of
|
||||||
|
a completed subtask.
|
||||||
|
|
||||||
|
Reference (verbatim from MEM, Torne 2026):
|
||||||
|
"Remove or compress information in the language memory whenever
|
||||||
|
appropriate. Keep ONLY the minimal set of relevant information for future
|
||||||
|
task execution. Specific object attributes (colors, precise quantities of
|
||||||
|
each item) get discarded when their details won't affect subsequent
|
||||||
|
actions. Functional outcomes (where items went, how many) are preserved."
|
||||||
|
|
||||||
|
Episode task: "{episode_task}"
|
||||||
|
Previous memory: {prior_memory}
|
||||||
|
Just-completed subtask: "{completed_subtask}"
|
||||||
|
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
|
||||||
|
|
||||||
|
Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the
|
||||||
|
robot has accomplished so far — the running story it would tell itself.
|
||||||
|
|
||||||
|
Authoring rules:
|
||||||
|
- First person, past tense. Every sentence starts with "I": "I picked
|
||||||
|
up...", "I opened...", "I moved to...".
|
||||||
|
- One or two short sentences. Extend the previous memory with the
|
||||||
|
just-completed subtask; do not rewrite it from scratch.
|
||||||
|
- Keep WHAT happened (functional outcomes — where items went, how many),
|
||||||
|
drop HOW (grasp details, motions).
|
||||||
|
- Compress completed steps and drop object attributes (colors, exact
|
||||||
|
counts) once they no longer affect the remaining subtasks.
|
||||||
|
|
||||||
|
Example (MEM, Torne 2026):
|
||||||
|
Before: "I prepared the pot and got the potatoes, milk, and butter. I
|
||||||
|
moved to the drawer."
|
||||||
|
After: "I prepared the pot and got the ingredients. I opened the
|
||||||
|
drawer with the masher."
|
||||||
|
|
||||||
|
Output strictly valid JSON:
|
||||||
|
{{ "memory": "<one or two short first-person past-tense sentences>" }}
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
You are labeling a teleoperated robot demonstration.
|
||||||
|
|
||||||
|
The user originally asked: "{episode_task}"
|
||||||
|
|
||||||
|
You are shown the entire demonstration as a single video. Watch the
|
||||||
|
whole clip, then segment it into a list of consecutive atomic subtasks
|
||||||
|
the robot performs.
|
||||||
|
|
||||||
|
Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts:
|
||||||
|
|
||||||
|
- Each subtask = one COMPOSITE atomic skill the low-level policy can
|
||||||
|
execute end-to-end. A "skill" bundles its own approach motion with
|
||||||
|
its terminal action — do NOT split the approach off as its own
|
||||||
|
subtask. The whole-arm policy already learns to reach as part of
|
||||||
|
every manipulation primitive.
|
||||||
|
- Write each subtask as an IMPERATIVE COMMAND, starting with one of
|
||||||
|
these verbs (extend only when none fits):
|
||||||
|
pick up <obj> — approach + grasp + lift in one subtask
|
||||||
|
put <obj> on/in <loc> — transport + release in one subtask
|
||||||
|
place <obj> on/in <loc> — synonym of "put"; pick one and stay consistent
|
||||||
|
push <obj> — contact + linear shove
|
||||||
|
pull <obj> — contact + linear retract
|
||||||
|
turn <knob/dial/handle> — rotary actuation
|
||||||
|
press <button> — single-press contact
|
||||||
|
open <drawer/door/lid> — full open motion
|
||||||
|
close <drawer/door/lid> — full close motion
|
||||||
|
pour <src> into <dst> — tilt + flow
|
||||||
|
insert <obj> into <slot>— alignment + push-fit
|
||||||
|
go to <loc> — ONLY when no grasp / actuation follows
|
||||||
|
(e.g. a pure relocation between phases).
|
||||||
|
If the next subtask grasps something at
|
||||||
|
that location, drop "go to ..." and just
|
||||||
|
write "pick up ..." instead.
|
||||||
|
- Forbidden ultra-fine splits — the VLM is NOT allowed to emit these
|
||||||
|
as standalone subtasks; fold them into the parent composite:
|
||||||
|
"move to X" → fold into "pick up X" (or whatever follows)
|
||||||
|
"reach for X" → fold into "pick up X"
|
||||||
|
"grasp X" → fold into "pick up X"
|
||||||
|
"lift X" → fold into "pick up X" (or "put X on Y" if it's
|
||||||
|
the transport phase of a place)
|
||||||
|
"release X" → fold into "put X on Y" (or "place X in Y")
|
||||||
|
- Keep it SHORT — a verb phrase, not a sentence. Drop articles
|
||||||
|
("the", "a") and adverbs ("carefully", "slowly"). Add a "how"
|
||||||
|
detail (which hand, which grasp point) ONLY when it is needed to
|
||||||
|
disambiguate. Every subtask must begin with one of the verbs
|
||||||
|
above (no leading nouns, no "then", no "first").
|
||||||
|
- NEVER use third person. Never write "the robot", "the arm", "the
|
||||||
|
gripper moves", "it picks up" — the robot is implied. Command it,
|
||||||
|
do not describe it.
|
||||||
|
- Use the exact object nouns from the task above. If the task says
|
||||||
|
"cube", every subtask says "cube" — never switch to "block". If it
|
||||||
|
says "box", never switch to "bin"/"container". Keep vocabulary
|
||||||
|
consistent across the whole episode.
|
||||||
|
- Good: "pick up blue cube", "put blue cube in box", "open drawer",
|
||||||
|
"turn red knob", "press start button", "go to sink".
|
||||||
|
- Bad: "move to blue cube" (approach as its own subtask — forbidden,
|
||||||
|
must be folded into "pick up blue cube"); "the robot arm moves
|
||||||
|
towards the blue cube" (third person, too long); "carefully pick
|
||||||
|
up the cube" (adverb, article); "release the yellow block"
|
||||||
|
("block" when the task said "cube", and "release" must be folded
|
||||||
|
into a "put"/"place" subtask).
|
||||||
|
- Subtasks are non-overlapping and cover the full episode in order.
|
||||||
|
Choose the cut points yourself based on what you see in the video
|
||||||
|
(gripper open/close events, contact, regrasps, transitions).
|
||||||
|
- Each subtask spans at least {min_subtask_seconds} seconds. If a
|
||||||
|
candidate span would be shorter, merge it into its neighbour
|
||||||
|
rather than emitting it.
|
||||||
|
- Do not exceed {max_steps} subtasks total. Fewer, larger composites
|
||||||
|
are preferred over many micro-steps.
|
||||||
|
- Every subtask's [start_time, end_time] must lie within
|
||||||
|
[0.0, {episode_duration}] seconds.
|
||||||
|
|
||||||
|
Output strictly valid JSON of shape:
|
||||||
|
|
||||||
|
{{
|
||||||
|
"subtasks": [
|
||||||
|
{{"text": "<short imperative verb phrase>", "start": <float>, "end": <float>}},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}}
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
You are generating structured augmentations of a robot task instruction
|
||||||
|
for training a language-conditioned policy. Unlike free-form rephrasing,
|
||||||
|
your variants follow a NAMED 5-axis taxonomy — each axis omits or varies
|
||||||
|
a specific element of the task while preserving its meaning.
|
||||||
|
|
||||||
|
Original task: "{base_task}"
|
||||||
|
|
||||||
|
Produce variants along five named axes. Each axis has a target count.
|
||||||
|
The whole batch should expose the policy to maximum linguistic diversity
|
||||||
|
WITHOUT changing what the robot is supposed to do.
|
||||||
|
|
||||||
|
Axes and target counts:
|
||||||
|
|
||||||
|
synonym_paraphrase ({n_synonym}):
|
||||||
|
Different wording / verbs / sentence structure. ALL information
|
||||||
|
from the original task is preserved — same object, same arm
|
||||||
|
specification if present, same orientation if present, same grasp
|
||||||
|
if present.
|
||||||
|
|
||||||
|
omit_arm ({n_omit_arm}):
|
||||||
|
Drop the left/right/both arm specification from the task. Skip
|
||||||
|
entirely (emit 0 entries) if the original task does NOT mention an
|
||||||
|
arm. Do not invent an arm specification just to omit it.
|
||||||
|
|
||||||
|
omit_orientation ({n_omit_orientation}):
|
||||||
|
Drop orientation cues (upright, sideways, facing the user,
|
||||||
|
long-edge-first, etc.). Skip entirely if no orientation cue is
|
||||||
|
present in the original task.
|
||||||
|
|
||||||
|
omit_grasp_method ({n_omit_grasp_method}):
|
||||||
|
Drop the grip / grasp method specification (pinch, wrap, hold by
|
||||||
|
the rim, etc.). Skip entirely if no grasp method is mentioned.
|
||||||
|
|
||||||
|
combined_omissions ({n_combined}):
|
||||||
|
Combine TWO of the above omissions simultaneously (e.g. drop both
|
||||||
|
arm and orientation). Skip entirely if fewer than two of (arm,
|
||||||
|
orientation, grasp_method) appear in the original task.
|
||||||
|
|
||||||
|
Hard rules:
|
||||||
|
- Each variant MUST preserve the core action and the target object.
|
||||||
|
Do not change which object is involved, the destination, or the
|
||||||
|
high-level action.
|
||||||
|
- Each variant is plain prose, no markdown, no quotes, no list numbers.
|
||||||
|
- Each variant must be DISTINCT from every other variant in the entire
|
||||||
|
output, both within and across axes. Near-duplicates are not allowed.
|
||||||
|
- If an axis cannot reach its target count because the original task
|
||||||
|
lacks the omittable element, emit fewer entries — do NOT pad the
|
||||||
|
axis with paraphrases that belong to a different axis.
|
||||||
|
- Variants should not all start with verbs — vary sentence structure
|
||||||
|
(some imperative, some polite request, some question).
|
||||||
|
|
||||||
|
Output strictly valid JSON of shape:
|
||||||
|
|
||||||
|
{{
|
||||||
|
"synonym_paraphrase": ["<v1>", "<v2>", ...],
|
||||||
|
"omit_arm": ["<v1>", "<v2>", ...],
|
||||||
|
"omit_orientation": ["<v1>", ...],
|
||||||
|
"omit_grasp_method": ["<v1>", ...],
|
||||||
|
"combined_omissions": ["<v1>", ...]
|
||||||
|
}}
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
You are generating training data for a Hi Robot-style policy. We need
|
||||||
|
{n} alternative phrasings of the same robot task so the policy sees
|
||||||
|
diverse user prompts during training instead of the same canonical
|
||||||
|
string repeated every frame.
|
||||||
|
|
||||||
|
Original task:
|
||||||
|
"{base_task}"
|
||||||
|
|
||||||
|
Generate exactly {n} alternative phrasings of the same task. Vary:
|
||||||
|
|
||||||
|
- formality (casual / polite / curt)
|
||||||
|
- verbosity (mostly short imperative; occasional polite request)
|
||||||
|
- word choice (synonyms, different verbs)
|
||||||
|
- sentence structure (imperative / question / suggestion)
|
||||||
|
|
||||||
|
Hard rules:
|
||||||
|
- Each phrasing MUST preserve the exact meaning of the original task.
|
||||||
|
Do not change which object is involved, the destination, or the
|
||||||
|
action. Do not add extra steps. Do not invent new objects.
|
||||||
|
- Each phrasing must be a short phrase or sentence, plain prose, no
|
||||||
|
markdown, no quotes, no list numbers.
|
||||||
|
- Phrasings must be distinct — no near-duplicates.
|
||||||
|
- Output exactly {n} entries.
|
||||||
|
|
||||||
|
Output strictly valid JSON:
|
||||||
|
{{
|
||||||
|
"rephrasings": [
|
||||||
|
"<phrasing 1>",
|
||||||
|
"<phrasing 2>",
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
The video above shows a robot manipulation episode in full. Look at
|
||||||
|
the entire video and describe in ONE concise sentence what the robot
|
||||||
|
is doing.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- One sentence, in natural English, like a user instruction.
|
||||||
|
- Capture the goal of the demonstration, not low-level motions.
|
||||||
|
Example: "place the yellow cube into the red bin" — not "move the
|
||||||
|
end-effector down 5cm and close the gripper".
|
||||||
|
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
|
||||||
|
- Do not invent objects or actions that aren't visible.
|
||||||
|
- Do not output anything other than the JSON object below.
|
||||||
|
|
||||||
|
Output strictly valid JSON:
|
||||||
|
{{
|
||||||
|
"task": "<single concise sentence describing what the robot does in this video>"
|
||||||
|
}}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
The user just asked the robot: "{episode_task}".
|
||||||
|
|
||||||
|
Generate a short verbal acknowledgement the robot would speak back before
|
||||||
|
beginning the task. Style: compact, confident, friendly.
|
||||||
|
|
||||||
|
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
|
||||||
|
"OK, starting with the sponge.", "Got it.".
|
||||||
|
|
||||||
|
Prefer very short replies: "Got it.", "On it.", "OK."
|
||||||
|
|
||||||
|
Output strictly valid JSON:
|
||||||
|
{{ "text": "<the spoken acknowledgement>" }}
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
You are generating training data for a Hi Robot-style hierarchical
|
||||||
|
robot policy. The robot in this demonstration has ALREADY executed
|
||||||
|
every step shown in the video — we cannot retroactively change the
|
||||||
|
action stream. To keep training data consistent with the video, the
|
||||||
|
"interjection" must align with what the robot is *about to do next* in
|
||||||
|
the demonstration, framed as a natural mid-task user request.
|
||||||
|
|
||||||
|
The episode's overall task: "{episode_task}".
|
||||||
|
|
||||||
|
The images above show roughly {window_seconds:.1f} seconds straddling a
|
||||||
|
subtask boundary in the demonstration:
|
||||||
|
|
||||||
|
- Subtask the robot just finished: "{prev_subtask}"
|
||||||
|
- Subtask the robot is about to start: "{next_subtask}"
|
||||||
|
- Time into episode: {timestamp:.2f}s
|
||||||
|
|
||||||
|
Write ONE compact interjection the user would naturally say at this
|
||||||
|
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
|
||||||
|
Keep it like a mid-task coaching cue, not a full instruction paragraph.
|
||||||
|
Also write the robot's compact verbal acknowledgement.
|
||||||
|
|
||||||
|
Hard rules:
|
||||||
|
|
||||||
|
- The interjection MUST be consistent with the next subtask. The user
|
||||||
|
cannot ask for something different from what the robot then does in
|
||||||
|
the video. If you're tempted to say "actually skip X" or "do Y
|
||||||
|
instead", DO NOT — those would contradict the demonstration.
|
||||||
|
- The interjection must reference an object, location, or action that
|
||||||
|
is plausible given the visible scene and the next subtask text.
|
||||||
|
- One short phrase or sentence each. Conversational, not robotic.
|
||||||
|
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
|
||||||
|
- Keep robot speech very short: "OK.", "On it.", "Doing that."
|
||||||
|
|
||||||
|
Style examples (vary the phrasing — don't reuse these verbatim):
|
||||||
|
- "Now go ahead and {next_subtask}."
|
||||||
|
- "Great, can you {next_subtask} next?"
|
||||||
|
- "{next_subtask}, please."
|
||||||
|
- "Before you continue, please {next_subtask}."
|
||||||
|
- "Looking good — {next_subtask} now."
|
||||||
|
- "Okay, {next_subtask}."
|
||||||
|
|
||||||
|
Output strictly valid JSON:
|
||||||
|
{{
|
||||||
|
"interjection": "<short cue from the user, asking for the next subtask>",
|
||||||
|
"speech": "<short robot acknowledgement>"
|
||||||
|
}}
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
You are generating a frame-grounded visual question/answer pair for
|
||||||
|
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
|
||||||
|
Policies — both train policies on grounded features such as bounding box
|
||||||
|
pixel coordinates, keypoints, counts, attributes, and spatial relations.
|
||||||
|
|
||||||
|
The frame shows a robot working on: "{episode_task}".
|
||||||
|
|
||||||
|
Question types and the EXACT answer JSON shape required for each:
|
||||||
|
|
||||||
|
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
|
||||||
|
"bbox": [x1, y1, x2, y2]}}, ...]}}
|
||||||
|
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
|
||||||
|
ECoT example: "a white cup [124, 25, 176, 113]".
|
||||||
|
|
||||||
|
keypoint => {{"label": "<point>", "point_format": "xy",
|
||||||
|
"point": [x, y]}}
|
||||||
|
|
||||||
|
count => {{"label": "<obj>", "count": <int>,
|
||||||
|
"note": "<optional short note>"}}
|
||||||
|
|
||||||
|
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
|
||||||
|
"value": "<observed value>"}}
|
||||||
|
|
||||||
|
spatial => {{"subject": "<obj>",
|
||||||
|
"relation": "<left_of|right_of|on|in|above|below|near>", "object": "<obj>"}}
|
||||||
|
|
||||||
|
Generate a question of type "{question_type}". Output strictly valid JSON:
|
||||||
|
|
||||||
|
{{
|
||||||
|
"question": "<short, frame-grounded question>",
|
||||||
|
"answer": <object whose shape matches the schema above>
|
||||||
|
}}
|
||||||
274
src/lerobot/annotations/steerable_pipeline/reader.py
Normal file
274
src/lerobot/annotations/steerable_pipeline/reader.py
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Datatrove-shaped reader.
|
||||||
|
|
||||||
|
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
|
||||||
|
episode containing:
|
||||||
|
|
||||||
|
- ``episode_index``: int
|
||||||
|
- ``frame_timestamps``: tuple[float, ...]
|
||||||
|
- ``frame_indices``: tuple[int, ...]
|
||||||
|
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
|
||||||
|
- ``data_path``: pathlib.Path of the source parquet shard
|
||||||
|
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
|
||||||
|
|
||||||
|
This shape lets each module operate per-episode without loading all parquet
|
||||||
|
rows into memory at once.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterator, Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
|
||||||
|
from lerobot.datasets.io_utils import load_tasks
|
||||||
|
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class EpisodeRecord:
    """Per-episode record yielded by the reader.

    Holds the episode's identity, its per-frame timestamps / frame indices,
    and the location of its rows (``row_offset`` / ``row_count``) inside the
    parquet shard at ``data_path``. The frame data itself is only loaded when
    :meth:`frames_df` is called.
    """

    episode_index: int
    episode_task: str
    frame_timestamps: tuple[float, ...]
    frame_indices: tuple[int, ...]
    data_path: Path
    row_offset: int  # row offset within the parquet file where this episode starts
    row_count: int  # number of rows for this episode

    # Memoized parquet slice — populated on first ``frames_df()`` call so
    # repeat queries from different modules don't re-read the whole shard.
    # Excluded from ``__init__``, ``repr`` and equality.
    _frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)

    def frames_df(self):  # type: ignore[no-untyped-def]
        """Lazy-load the pandas slice for this episode (memoized).

        Returns:
            A ``pandas.DataFrame`` with the ``row_count`` rows starting at
            ``row_offset`` in ``data_path``, re-indexed from 0. The same
            object is returned on every later call.
        """
        if self._frames_df_cache is None:
            import pandas as pd  # noqa: PLC0415 - deferred for optional dataset extra

            table = pq.read_table(self.data_path)
            # Slice on the Arrow side before converting: only this episode's
            # rows get materialised as pandas objects, instead of the whole
            # shard being converted and then mostly thrown away.
            episode_rows = table.slice(self.row_offset, self.row_count)
            df: pd.DataFrame = episode_rows.to_pandas()
            self._frames_df_cache = df.reset_index(drop=True)
        return self._frames_df_cache
|
||||||
|
|
||||||
|
|
||||||
|
def reconstruct_subtask_spans(
|
||||||
|
rows: Sequence[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
episode_end_t: float | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
|
||||||
|
|
||||||
|
Each span's ``end`` is the next span's ``start``. The final span's
|
||||||
|
``end`` defaults to its own ``start`` (zero-duration) — pass
|
||||||
|
``episode_end_t`` to extend it to the episode's last frame instead,
|
||||||
|
which is what downstream consumers (memory, interjection boundary
|
||||||
|
selection) expect.
|
||||||
|
|
||||||
|
Used by the ``plan`` module (plan-update pass) and the
|
||||||
|
``interjections`` module (interjection anchoring), which both need the
|
||||||
|
same span shape.
|
||||||
|
"""
|
||||||
|
sorted_rows = sorted(
|
||||||
|
(r for r in rows if r.get("style") == "subtask"),
|
||||||
|
key=lambda r: float(r["timestamp"]),
|
||||||
|
)
|
||||||
|
spans: list[dict[str, Any]] = []
|
||||||
|
for r in sorted_rows:
|
||||||
|
t = float(r["timestamp"])
|
||||||
|
if spans:
|
||||||
|
spans[-1]["end"] = t
|
||||||
|
spans.append({"text": r.get("content") or "", "start": t, "end": t})
|
||||||
|
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
|
||||||
|
spans[-1]["end"] = float(episode_end_t)
|
||||||
|
return spans
|
||||||
|
|
||||||
|
|
||||||
|
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
    """Return the entry of ``frame_timestamps`` closest to ``t``.

    Event-style rows have to carry a timestamp that exists in the source
    frames, so modules route their timestamps through this helper before
    emitting them. When ``frame_timestamps`` is empty there is nothing to
    snap to and ``t`` is returned unchanged (as a float). On a tie the
    earliest-listed timestamp wins.
    """
    if not frame_timestamps:
        return float(t)

    candidates = iter(frame_timestamps)
    closest = next(candidates)
    closest_gap = abs(closest - t)
    for stamp in candidates:
        gap = abs(stamp - t)
        # Strict comparison keeps the first of several equally close frames.
        if gap < closest_gap:
            closest, closest_gap = stamp, gap
    return float(closest)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_tasks_lookup(root: Path) -> dict[int, str]:
    """Build a ``task_index -> task`` mapping for the dataset at ``root``.

    An empty dict is returned when ``root / DEFAULT_TASKS_PATH`` does not
    exist. Otherwise the tasks frame from
    :func:`lerobot.datasets.io_utils.load_tasks` is used: its index supplies
    the task strings and its ``task_index`` column supplies the keys.
    """
    tasks_file = root / DEFAULT_TASKS_PATH
    if not tasks_file.exists():
        return {}

    tasks_frame = load_tasks(root)
    lookup: dict[int, str] = {}
    for task_text, task_index in zip(tasks_frame.index, tasks_frame["task_index"], strict=True):
        lookup[int(task_index)] = str(task_text)
    return lookup
|
||||||
|
|
||||||
|
|
||||||
|
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
    """Yield an :class:`EpisodeRecord` for each episode found under ``root/data/``.

    No particular chunk/file layout is assumed: every ``*.parquet`` below
    ``data/`` is scanned. Shards are visited in sorted path order and, within
    a shard, episodes come out in row order.
    NOTE(review): that is ascending ``episode_index`` only if the dataset is
    written that way — nothing here re-sorts the records.

    Args:
        root: Dataset root directory.
        only_episodes: When given, only episodes whose index is in this
            collection are yielded.
    """
    task_by_index = _load_tasks_lookup(root)
    shards = sorted((root / "data").rglob("*.parquet"))
    wanted = None if only_episodes is None else set(only_episodes)

    for shard in shards:
        yield from _iter_one_path(shard, task_by_index, wanted)
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
    """Yield one :class:`EpisodeRecord` per consecutive ``episode_index`` run in ``path``.

    Only the bookkeeping columns (``episode_index``, ``timestamp``,
    ``frame_index``, ``task_index``) are read from the shard; any other
    column is left on disk. Files without an ``episode_index`` column yield
    nothing.

    Fallbacks for missing optional columns: timestamps default to ``0.0``
    for every row, frame indices default to the row's position in the file,
    and the task text defaults to ``""``.

    Args:
        path: Parquet shard to scan.
        tasks: ``task_index -> task`` lookup; unknown indices map to ``""``.
        only_set: When not ``None``, only episodes whose index is in this
            set are yielded.
    """
    available = set(pq.read_schema(path).names)
    if "episode_index" not in available:
        return

    # Nothing below touches the other columns, so don't pay to decode them.
    wanted = [
        name for name in ("episode_index", "timestamp", "frame_index", "task_index") if name in available
    ]
    table = pq.read_table(path, columns=wanted)

    episode_col = table.column("episode_index").to_pylist()
    n_rows = len(episode_col)
    timestamp_col = table.column("timestamp").to_pylist() if "timestamp" in available else [0.0] * n_rows
    frame_col = table.column("frame_index").to_pylist() if "frame_index" in available else list(range(n_rows))
    task_col = table.column("task_index").to_pylist() if "task_index" in available else None

    # Walk the column once, closing a run whenever the episode id changes or
    # the file ends. ``run_start``/``run_end`` are row offsets in this file.
    run_start = 0
    for run_end in range(1, n_rows + 1):
        if run_end < n_rows and episode_col[run_end] == episode_col[run_start]:
            continue
        ep = episode_col[run_start]
        # Rows with a null episode id never form a record.
        if ep is not None and (only_set is None or ep in only_set):
            # The task of an episode is taken from its first row.
            task_idx = task_col[run_start] if task_col is not None else None
            yield EpisodeRecord(
                episode_index=ep,
                episode_task=tasks.get(task_idx, "") if task_idx is not None else "",
                frame_timestamps=tuple(timestamp_col[run_start:run_end]),
                frame_indices=tuple(frame_col[run_start:run_end]),
                data_path=path,
                row_offset=run_start,
                row_count=run_end - run_start,
            )
        run_start = run_end
|
||||||
|
|
||||||
|
|
||||||
|
def gather_data_paths(root: Path) -> list[Path]:
    """Return every ``*.parquet`` file found recursively under ``root/data``, sorted by path."""
    shards = list((root / "data").rglob("*.parquet"))
    shards.sort()
    return shards
|
||||||
|
|
||||||
|
|
||||||
|
def episode_offsets_per_path(path: Path) -> dict[int, tuple[int, int]]:
    """Map each episode in one parquet file to its ``(row_offset, row_count)``.

    Only the ``episode_index`` column is read. Rows are grouped into
    consecutive runs of the same episode id; if an id shows up in more than
    one run, the last run is the one reported. Null ids are not reported.
    """
    episode_ids = pq.read_table(path, columns=["episode_index"]).column("episode_index").to_pylist()
    total = len(episode_ids)

    offsets: dict[int, tuple[int, int]] = {}
    run_start = 0
    for pos in range(1, total + 1):
        # Keep extending the run while the id is unchanged and rows remain.
        if pos < total and episode_ids[pos] == episode_ids[run_start]:
            continue
        run_episode = episode_ids[run_start]
        if run_episode is not None:
            offsets[run_episode] = (run_start, pos - run_start)
        run_start = pos
    return offsets
|
||||||
|
|
||||||
|
|
||||||
|
def keyframe_indices(record: EpisodeRecord, k: int) -> list[int]:
    """Pick ``k`` evenly spread row indices, relative to the episode start.

    Returns an empty list when ``k <= 0`` or the episode has no rows, every
    row index when ``k`` is at least the row count, the middle row when
    ``k == 1``, and otherwise ``k`` indices running from the first to the
    last row (inclusive) with rounded, equal spacing.
    """
    total = record.row_count
    if total == 0 or k <= 0:
        return []
    if k >= total:
        return list(range(total))
    if k == 1:
        return [total // 2]
    stride = (total - 1) / (k - 1)
    picks: list[int] = []
    for slot in range(k):
        picks.append(int(round(slot * stride)))
    return picks
|
||||||
|
|
||||||
|
|
||||||
|
def lookup_data_path(root: Path, episode_index: int) -> tuple[Path, int, int] | None:
    """Locate ``episode_index`` among the parquet files under ``root``.

    Files are scanned in the order returned by :func:`gather_data_paths`; the
    first file containing the episode wins. Returns
    ``(path, first_row, row_count)`` or ``None`` when no file contains it.
    """
    for parquet_path in gather_data_paths(root):
        bounds = episode_offsets_per_path(parquet_path).get(episode_index)
        if bounds is None:
            continue
        first_row, n_rows = bounds
        return parquet_path, first_row, n_rows
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def episode_frame_timestamps(root: Path, episode_index: int) -> tuple[Any, list[float]]:
    """Return the parquet path holding ``episode_index`` and its frame timestamps.

    Raises:
        ValueError: if no parquet file under ``root`` contains the episode.
    """
    location = lookup_data_path(root, episode_index)
    if location is None:
        raise ValueError(f"Episode {episode_index} not found under {root}/data/")
    path, start, count = location
    # Read the whole ``timestamp`` column, then keep only this episode's rows.
    column = pq.read_table(path, columns=["timestamp"]).column("timestamp").to_pylist()
    episode_slice = column[start : start + count]
    return path, [float(value) for value in episode_slice]
|
||||||
104
src/lerobot/annotations/steerable_pipeline/staging.py
Normal file
104
src/lerobot/annotations/steerable_pipeline/staging.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Per-episode staging.
|
||||||
|
|
||||||
|
Each module writes its raw output as a JSONL file under
|
||||||
|
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
|
||||||
|
staging tree and partitions rows into the two language columns.
|
||||||
|
|
||||||
|
JSONL is preferred over parquet here because the staging artifact is meant to
|
||||||
|
be human-inspectable, easy to diff between prompt iterations, and trivially
|
||||||
|
appended to. The final dataset format is parquet; staging is just an
|
||||||
|
intermediate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import Iterable, Iterator
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
ModuleName = str

# Closed set of module names allowed to stage output; each one maps to a
# single ``<module>.jsonl`` file inside an episode directory.
_MODULES: tuple[ModuleName, ...] = (
    "plan",
    "interjections",
    "vqa",
)


@dataclass
class EpisodeStaging:
    """Filesystem layout for a single episode's staged module outputs.

    Files live at ``<root>/episode_{episode_index:06d}/<module>.jsonl`` with
    one JSON object per line.
    """

    root: Path
    episode_index: int

    @property
    def episode_dir(self) -> Path:
        """Directory holding every staged file of this episode."""
        return self.root / f"episode_{self.episode_index:06d}"

    def path_for(self, module: ModuleName) -> Path:
        """Return the JSONL path for ``module``.

        Raises:
            ValueError: if ``module`` is not one of the known module names.
        """
        if module not in _MODULES:
            raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
        return self.episode_dir / f"{module}.jsonl"

    def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
        """Replace the staged file of ``module`` with ``rows`` and return its path.

        The rows are written to a sibling ``.tmp`` file which is then renamed
        over the target, so the target path only ever points at a complete
        file. If anything fails before the rename (for example a row that is
        not JSON-serializable), the temporary file is removed, any previously
        staged file is left untouched, and the exception propagates.
        """
        path = self.path_for(module)
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = path.with_suffix(path.suffix + ".tmp")
        try:
            with tmp_path.open("w", encoding="utf-8") as f:
                for row in rows:
                    f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
                    f.write("\n")
            tmp_path.replace(path)
        except BaseException:
            # Do not leave a half-written temp file next to the real artifact.
            tmp_path.unlink(missing_ok=True)
            raise
        return path

    def read(self, module: ModuleName) -> list[dict[str, Any]]:
        """Return the staged rows of ``module``; ``[]`` when nothing is staged.

        Blank lines are skipped. A line that is not valid JSON raises
        ``json.JSONDecodeError``.
        """
        path = self.path_for(module)
        try:
            # Open directly instead of checking ``exists()`` first so a file
            # removed in between cannot turn into an unexpected error.
            handle = path.open(encoding="utf-8")
        except FileNotFoundError:
            return []
        out: list[dict[str, Any]] = []
        with handle as f:
            for line in f:
                line = line.strip()
                if line:
                    out.append(json.loads(line))
        return out

    def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
        """Return ``{module: rows}`` for every known module (missing ones map to ``[]``)."""
        return {m: self.read(m) for m in _MODULES}

    def has(self, module: ModuleName) -> bool:
        """Tell whether a staged file exists for ``module``."""
        return self.path_for(module).exists()
|
||||||
|
|
||||||
|
|
||||||
|
def iter_staged_episodes(root: Path) -> Iterator[int]:
    """Yield episode indices for which a staging directory exists under ``root``.

    Only directories named ``episode_<digits>`` (ASCII digits only) are
    reported, in sorted directory-name order. Names that ``int()`` would
    otherwise tolerate, such as ``episode_1_0``, ``episode_-3`` or
    ``episode_ 7``, are ignored rather than mapped to a surprising index.
    Yields nothing when ``root`` does not exist.
    """
    if not root.exists():
        return
    prefix = "episode_"
    for child in sorted(root.iterdir()):
        name = child.name
        if not (child.is_dir() and name.startswith(prefix)):
            continue
        suffix = name.removeprefix(prefix)
        # Validate before yielding so no ``yield`` sits inside a ``try`` that
        # could swallow an exception thrown in by the consumer.
        if suffix.isascii() and suffix.isdigit():
            yield int(suffix)
|
||||||
334
src/lerobot/annotations/steerable_pipeline/validator.py
Normal file
334
src/lerobot/annotations/steerable_pipeline/validator.py
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Pre-write validation against staged outputs.
|
||||||
|
|
||||||
|
Runs after all three modules have written their per-episode artifacts but
|
||||||
|
*before* the writer rewrites parquet shards. The validator never touches
|
||||||
|
parquet; it only inspects the staging tree and the source frame timestamps
|
||||||
|
exposed by :class:`EpisodeRecord`.
|
||||||
|
|
||||||
|
Checks (per the plan's "Intermediate staging and validation" section):
|
||||||
|
|
||||||
|
- exact timestamp alignment against source frame timestamps
|
||||||
|
- no orphan speech / interjection pairs
|
||||||
|
- plan / memory emission consistency (events have a paired persistent row)
|
||||||
|
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
|
||||||
|
attribute / spatial)
|
||||||
|
- every row maps to its correct column under :func:`column_for_style`
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import Iterable, Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from lerobot.datasets.language import (
|
||||||
|
LANGUAGE_EVENTS,
|
||||||
|
LANGUAGE_PERSISTENT,
|
||||||
|
column_for_style,
|
||||||
|
is_view_dependent_style,
|
||||||
|
validate_camera_field,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .reader import EpisodeRecord
|
||||||
|
from .staging import EpisodeStaging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ValidationReport:
    """Accumulated result of a validation pass over a set of episodes."""

    errors: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)
    episodes_checked: int = 0

    @property
    def ok(self) -> bool:
        """``True`` while no error has been recorded (warnings do not count)."""
        return len(self.errors) == 0

    def add_error(self, message: str) -> None:
        """Record ``message`` as an error."""
        self.errors.append(message)

    def add_warning(self, message: str) -> None:
        """Record ``message`` as a warning."""
        self.warnings.append(message)

    def summary(self) -> str:
        """Return a one-line ``checked=… errors=… warnings=…`` summary."""
        counters = (
            ("checked", self.episodes_checked),
            ("errors", len(self.errors)),
            ("warnings", len(self.warnings)),
        )
        return " ".join(f"{label}={value}" for label, value in counters)
|
||||||
|
|
||||||
|
|
||||||
|
# Keys that must all be present for a payload to count as the given VQA
# answer type. Extra keys are allowed.
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
    "bbox": {"detections"},
    "keypoint": {"label", "point_format", "point"},
    "count": {"label", "count"},
    "attribute": {"label", "attribute", "value"},
    "spatial": {"subject", "relation", "object"},
}


def classify_vqa_answer(payload: Any) -> str | None:
    """Return the first answer type whose required keys ``payload`` contains.

    Types are tried in the declaration order of ``VQA_ANSWER_SHAPES``.
    Returns ``None`` when ``payload`` is not a dict or matches no type.
    """
    if not isinstance(payload, dict):
        return None
    present = payload.keys()
    matches = (kind for kind, required in VQA_ANSWER_SHAPES.items() if required.issubset(present))
    return next(matches, None)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StagingValidator:
    """Walks the staging tree and produces a :class:`ValidationReport`.

    For each episode, the staged rows of every module are loaded, tagged with
    their originating module under the ``"_module"`` key, routed into
    persistent / event rows with :func:`column_for_style`, and passed through
    the individual ``_check_*`` methods. Findings are appended to the report.
    """

    # Allowed |event timestamp - frame timestamp| in seconds; ``0.0`` (the
    # default) demands an exact match.
    timestamp_atol: float = 0.0
    # Known ``observation.images.*`` keys on the dataset. When set, the
    # validator additionally enforces that every view-dependent row's
    # ``camera`` field references one of these keys. Pass ``None`` (default)
    # to skip that cross-check (e.g. in unit tests with no real dataset).
    dataset_camera_keys: tuple[str, ...] | None = None

    def validate(
        self,
        records: Sequence[EpisodeRecord],
        staging_dir: Path,
    ) -> ValidationReport:
        """Validate the staged output of every episode in ``records``."""
        report = ValidationReport()
        for record in records:
            self._validate_episode(record, staging_dir, report)
            report.episodes_checked += 1
        return report

    def _validate_episode(
        self,
        record: EpisodeRecord,
        staging_dir: Path,
        report: ValidationReport,
    ) -> None:
        """Run every check against one episode's staged rows."""
        staging = EpisodeStaging(staging_dir, record.episode_index)
        all_rows: list[dict[str, Any]] = []
        for module_name, rows in staging.read_all().items():
            for row in rows:
                all_rows.append({**row, "_module": module_name})

        frame_ts = set(record.frame_timestamps)

        events: list[dict[str, Any]] = []
        persistent: list[dict[str, Any]] = []
        for row in all_rows:
            self._check_column_routing(row, report, record.episode_index)
            self._check_camera_field(row, report, record.episode_index, self.dataset_camera_keys)
            try:
                target_col = column_for_style(row.get("style"))
            except ValueError:
                # Unknown style: ``_check_column_routing`` has already
                # reported it. The row cannot be routed, so skip it instead
                # of letting the exception abort the whole validation run.
                continue
            if target_col == LANGUAGE_PERSISTENT:
                persistent.append(row)
            else:
                events.append(row)

        for row in events:
            self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)

        self._check_speech_interjection_pairs(events, report, record.episode_index)
        self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
        self._check_vqa_json(events, report, record.episode_index)
        self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)

    def _check_camera_field(
        self,
        row: dict[str, Any],
        report: ValidationReport,
        episode_index: int,
        dataset_camera_keys: Sequence[str] | None,
    ) -> None:
        """Enforce the camera invariant + that the key matches the dataset's cameras."""
        style = row.get("style")
        camera = row.get("camera")
        try:
            validate_camera_field(style, camera)
        except ValueError as exc:
            report.add_error(f"ep={episode_index} module={row.get('_module')}: {exc}")
            return
        # The dataset cross-check only runs when camera keys were supplied.
        if is_view_dependent_style(style) and dataset_camera_keys and camera not in dataset_camera_keys:
            report.add_error(
                f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
                f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
            )

    def _check_vqa_uniqueness_per_frame_camera(
        self,
        events: Iterable[dict[str, Any]],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        """Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
        counts: dict[tuple[float, str, str], int] = {}
        for row in events:
            if row.get("style") != "vqa":
                continue
            ts = row.get("timestamp")
            camera = row.get("camera")
            role = row.get("role")
            if ts is None or camera is None or role is None:
                continue  # other validators flag these
            key = (float(ts), str(camera), str(role))
            counts[key] = counts.get(key, 0) + 1
        for (ts, camera, role), n in counts.items():
            if n > 1:
                report.add_error(
                    f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
                    f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
                )

    def _check_column_routing(
        self,
        row: dict[str, Any],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        """Check that the row's style routes to the column its module may write.

        ``plan`` rows must route to the persistent column; ``interjections``
        and ``vqa`` rows must route to the events column. An unknown style is
        reported as an error.
        """
        style = row.get("style")
        module = row.get("_module")
        try:
            target_col = column_for_style(style)
        except ValueError:
            report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
            return
        if module == "plan" and target_col != LANGUAGE_PERSISTENT:
            report.add_error(
                f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
            )
        if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
            report.add_error(
                f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
            )

    def _check_event_timestamp_alignment(
        self,
        row: dict[str, Any],
        frame_ts: set[float],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        """Check that an event row's timestamp coincides with a source frame.

        With ``timestamp_atol == 0.0`` the timestamp must be exactly one of
        ``frame_ts``; otherwise it must lie within ``timestamp_atol`` of one.
        A missing timestamp is reported as an error.
        """
        ts = row.get("timestamp")
        if ts is None:
            report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
            return
        ts_f = float(ts)
        if self.timestamp_atol == 0.0:
            if ts_f not in frame_ts:
                report.add_error(
                    f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
                )
        elif not any(abs(ts_f - f) <= self.timestamp_atol for f in frame_ts):
            report.add_error(
                f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
            )

    def _check_speech_interjection_pairs(
        self,
        events: Iterable[dict[str, Any]],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        """Report interjections with no speech row at the same timestamp.

        A speech row is an event row with no ``style`` and role
        ``"assistant"``. Rows without a timestamp are ignored here.
        """
        speech_ts: set[float] = set()
        # A dict keeps first-seen order so errors come out in a stable order.
        interjection_ts: dict[float, None] = {}
        for row in events:
            ts = row.get("timestamp")
            if ts is None:
                continue
            ts_f = float(ts)
            if row.get("style") is None and row.get("role") == "assistant":
                speech_ts.add(ts_f)
            if row.get("style") == "interjection":
                interjection_ts[ts_f] = None

        for ts in interjection_ts:
            if ts not in speech_ts:
                report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")

    def _persistent_timestamps(
        self,
        persistent: Sequence[dict[str, Any]],
        style: str,
        report: ValidationReport,
        episode_index: int,
    ) -> set[float]:
        """Collect the timestamps of persistent rows with the given ``style``.

        Rows lacking a timestamp are reported as errors and left out, rather
        than crashing the float conversion.
        """
        found: set[float] = set()
        for row in persistent:
            if row.get("style") != style:
                continue
            ts = row.get("timestamp")
            if ts is None:
                report.add_error(f"ep={episode_index}: {style} row missing timestamp: {row!r}")
                continue
            found.add(float(ts))
        return found

    def _check_plan_memory_consistency(
        self,
        persistent: Sequence[dict[str, Any]],
        events: Sequence[dict[str, Any]],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        """Cross-check plan / memory / subtask rows against interjections.

        - warns when persistent rows exist but none is a plan with a timestamp;
        - errors for each interjection without a plan row at the same timestamp;
        - warns about memory rows whose timestamp is not a subtask timestamp
          (only when both memory and subtask rows exist).
        """
        plan_ts = self._persistent_timestamps(persistent, "plan", report, episode_index)
        memory_ts = self._persistent_timestamps(persistent, "memory", report, episode_index)
        subtask_ts = self._persistent_timestamps(persistent, "subtask", report, episode_index)
        interjection_ts = sorted(
            {
                float(r["timestamp"])
                for r in events
                if r.get("style") == "interjection" and r.get("timestamp") is not None
            }
        )

        if persistent and not plan_ts:
            report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
        # every interjection should have a same-timestamp plan refresh
        for ts in interjection_ts:
            if ts not in plan_ts:
                report.add_error(
                    f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
                )
        # memory should be emitted at subtask boundaries (subset relation)
        if memory_ts and subtask_ts:
            stray = sorted(memory_ts - subtask_ts)
            if stray:
                report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")

    def _check_vqa_json(
        self,
        events: Iterable[dict[str, Any]],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        """Check that each VQA assistant row's ``content`` is JSON of a known shape."""
        for row in events:
            if row.get("style") != "vqa" or row.get("role") != "assistant":
                continue
            content = row.get("content")
            if content is None:
                report.add_error(
                    f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
                )
                continue
            try:
                payload = json.loads(content)
            except (TypeError, ValueError) as exc:
                report.add_error(
                    f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
                )
                continue
            shape = classify_vqa_answer(payload)
            if shape is None:
                report.add_error(
                    f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
                )
|
||||||
703
src/lerobot/annotations/steerable_pipeline/vlm_client.py
Normal file
703
src/lerobot/annotations/steerable_pipeline/vlm_client.py
Normal file
@@ -0,0 +1,703 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Shared Qwen-VL client.
|
||||||
|
|
||||||
|
The pipeline uses a single shared VLM across modules. vLLM is preferred when
|
||||||
|
available (high throughput, JSON-guided decoding); transformers is the
|
||||||
|
fallback. A ``stub`` backend is used for unit tests so fixtures never call
|
||||||
|
into a real model.
|
||||||
|
|
||||||
|
The client speaks one method, :meth:`VlmClient.generate_json`, which:
|
||||||
|
|
||||||
|
- accepts a list of OpenAI/HF-style multimodal messages,
|
||||||
|
- requests JSON output (``json_mode=True`` enables guided decoding when the
|
||||||
|
backend supports it),
|
||||||
|
- batches requests transparently,
|
||||||
|
- and reprompts once on a JSON parse failure with an inline correction
|
||||||
|
message; if the retry also fails to parse, that request yields ``None``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import urllib.request
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from .config import VlmConfig
|
||||||
|
|
||||||
|
|
||||||
|
class VlmClient(Protocol):
    """Protocol every backend must implement.

    A backend exposes a single batched entry point, :meth:`generate_json`.
    """

    def generate_json(
        self,
        messages_batch: Sequence[Sequence[dict[str, Any]]],
        *,
        max_new_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[Any]:
        """Generate one JSON-decoded response per messages list.

        Args:
            messages_batch: One chat (a sequence of message dicts) per request.
            max_new_tokens: Optional per-call generation length override.
            temperature: Optional per-call sampling temperature override.

        Returns:
            A list with one decoded value per entry of ``messages_batch``,
            in the same order.
        """
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StubVlmClient:
    """Deterministic stub used in unit tests.

    A test passes a ``responder`` callable. For every request in the batch it
    is called with that request's full message list (copied into a ``list``)
    and must return the already-decoded, JSON-serializable response. No model
    is involved, and ``max_new_tokens`` / ``temperature`` are accepted only
    for signature compatibility and otherwise ignored.
    """

    responder: Callable[[Sequence[dict[str, Any]]], Any]

    def generate_json(
        self,
        messages_batch: Sequence[Sequence[dict[str, Any]]],
        *,
        max_new_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[Any]:
        """Return ``responder(messages)`` for each message list, in order."""
        return [self.responder(list(messages)) for messages in messages_batch]
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_to_json(text: str) -> Any:
|
||||||
|
text = text.strip()
|
||||||
|
# Strip <think>...</think> blocks (Qwen3 Thinking style)
|
||||||
|
while "<think>" in text and "</think>" in text:
|
||||||
|
start = text.find("<think>")
|
||||||
|
end = text.find("</think>", start) + len("</think>")
|
||||||
|
text = (text[:start] + text[end:]).strip()
|
||||||
|
# Strip ```json ... ``` fences from chat-tuned backbones
|
||||||
|
if text.startswith("```"):
|
||||||
|
first = text.find("\n")
|
||||||
|
last = text.rfind("```")
|
||||||
|
if first != -1 and last != -1 and last > first:
|
||||||
|
text = text[first + 1 : last].strip()
|
||||||
|
try:
|
||||||
|
return json.loads(text)
|
||||||
|
except (ValueError, json.JSONDecodeError):
|
||||||
|
pass
|
||||||
|
# Fall back to extracting the first balanced {...} block.
|
||||||
|
obj_text = _extract_first_json_object(text)
|
||||||
|
if obj_text is None:
|
||||||
|
raise json.JSONDecodeError("No JSON object found", text, 0)
|
||||||
|
return json.loads(obj_text)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_first_json_object(text: str) -> str | None:
|
||||||
|
"""Return the first balanced ``{...}`` substring, ignoring braces in
|
||||||
|
string literals. Returns ``None`` if no balanced block is found."""
|
||||||
|
start = text.find("{")
|
||||||
|
if start < 0:
|
||||||
|
return None
|
||||||
|
depth = 0
|
||||||
|
in_string = False
|
||||||
|
escape = False
|
||||||
|
for i in range(start, len(text)):
|
||||||
|
ch = text[i]
|
||||||
|
if escape:
|
||||||
|
escape = False
|
||||||
|
continue
|
||||||
|
if ch == "\\":
|
||||||
|
escape = True
|
||||||
|
continue
|
||||||
|
# Note: ``escape`` is always False here — the ``if escape`` branch
|
||||||
|
# above already handled and reset it.
|
||||||
|
if ch == '"':
|
||||||
|
in_string = not in_string
|
||||||
|
continue
|
||||||
|
if in_string:
|
||||||
|
continue
|
||||||
|
if ch == "{":
|
||||||
|
depth += 1
|
||||||
|
elif ch == "}":
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
return text[start : i + 1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _GenericTextClient:
    """Adds JSON decoding with a single retry on top of a text-generation callable.

    ``generate_text(batch, max_tokens, temperature)`` must return one raw
    string per chat in ``batch``.
    """

    generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
    config: VlmConfig

    def generate_json(
        self,
        messages_batch: Sequence[Sequence[dict[str, Any]]],
        *,
        max_new_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[Any]:
        """Generate and decode one JSON value per chat in ``messages_batch``.

        ``max_new_tokens`` / ``temperature`` fall back to the values on
        ``config`` when ``None``. A reply that cannot be decoded is retried
        once, on its own, with the bad reply and a correction request appended
        to the chat. If the retry is also undecodable, a warning is printed and
        ``None`` is stored for that request.
        """
        token_budget = self.config.max_new_tokens if max_new_tokens is None else max_new_tokens
        sampling_temp = self.config.temperature if temperature is None else temperature
        first_pass = self.generate_text(messages_batch, token_budget, sampling_temp)
        results: list[Any] = []
        for messages, reply in zip(messages_batch, first_pass, strict=True):
            try:
                results.append(_strip_to_json(reply))
            except ValueError:  # includes json.JSONDecodeError
                pass
            else:
                continue
            correction = {
                "role": "user",
                "content": (
                    "Your previous reply was not valid JSON. "
                    "Reply with strictly valid JSON, no prose, no fences."
                ),
            }
            retry_chat = [*messages, {"role": "assistant", "content": reply}, correction]
            second_reply = self.generate_text([retry_chat], token_budget, sampling_temp)[0]
            try:
                results.append(_strip_to_json(second_reply))
            except ValueError:
                # After retry: log preview and return None instead of crashing
                # the whole pipeline. Modules treat None as "skip".
                preview = second_reply.strip().replace("\n", " ")[:200]
                print(
                    f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
                    flush=True,
                )
                results.append(None)
        return results
|
||||||
|
|
||||||
|
|
||||||
|
def make_vlm_client(config: VlmConfig) -> VlmClient:
|
||||||
|
"""Build the shared VLM client per the configured backend.
|
||||||
|
|
||||||
|
For ``stub``, callers should construct :class:`StubVlmClient` directly with
|
||||||
|
a responder callable. ``stub`` here is rejected to make accidental misuse
|
||||||
|
obvious.
|
||||||
|
"""
|
||||||
|
if config.backend == "stub":
|
||||||
|
raise ValueError(
|
||||||
|
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
|
||||||
|
)
|
||||||
|
if config.backend == "vllm":
|
||||||
|
return _make_vllm_client(config)
|
||||||
|
if config.backend == "transformers":
|
||||||
|
return _make_transformers_client(config)
|
||||||
|
if config.backend == "openai":
|
||||||
|
return _make_openai_client(config)
|
||||||
|
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vllm_client(config: VlmConfig) -> VlmClient:
    """Build a client backed by an in-process vLLM engine.

    Raises:
        ImportError: if ``vllm`` is not installed.
    """
    try:
        from vllm import LLM, SamplingParams  # type: ignore[import-not-found]
    except ImportError as exc:
        raise ImportError(
            "vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
        ) from exc

    # Workaround for cuDNN 9.x + torch 2.8 conv3d regression that surfaces
    # as CUDNN_STATUS_NOT_INITIALIZED in Qwen-VL vision-tower patch
    # embedders. Setting LEROBOT_DISABLE_CUDNN=1 forces native PyTorch
    # convolution kernels — slower but functional.
    cudnn_switch = os.environ.get("LEROBOT_DISABLE_CUDNN", "").lower()
    if cudnn_switch in {"1", "true", "yes"}:
        import torch as _torch  # noqa: PLC0415 - optional GPU dep, deferred

        _torch.backends.cudnn.enabled = False

    engine_args: dict[str, Any] = {
        "model": config.model_id,
        "tensor_parallel_size": config.tensor_parallel_size,
        "gpu_memory_utilization": config.gpu_memory_utilization,
        "trust_remote_code": config.trust_remote_code,
    }
    # Only override the context length when the config asks for one.
    if config.max_model_len is not None:
        engine_args["max_model_len"] = config.max_model_len
    engine = LLM(**engine_args)

    def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
        # ``guided_decoding`` would speed up parsing but its API differs across
        # vllm releases (dict vs GuidedDecodingParams). The _GenericTextClient
        # wrapper already has a one-retry JSON-recovery path, so we skip it.
        sampling = SamplingParams(max_tokens=max_tok, temperature=temp)
        chats = [list(chat) for chat in batch]
        # ``chat`` handles chat-template application + multimodal input
        # extraction (image/video blocks) internally, which ``generate``
        # does not.
        completions = engine.chat(chats, sampling)
        return [completion.outputs[0].text for completion in completions]

    return _GenericTextClient(_gen, config)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_transformers_client(config: VlmConfig) -> VlmClient:
    """Build a client backed by an in-process Hugging Face ``transformers`` model.

    The model class is ``AutoModelForImageTextToText`` when available, else
    ``AutoModelForVision2Seq``. By default the weights are loaded in bfloat16
    and moved to ``"cuda"``; setting ``LEROBOT_TRANSFORMERS_DEVICE_MAP`` to
    anything other than ``manual`` loads with ``device_map="auto"`` instead.

    Raises:
        ImportError: if ``torch`` / ``transformers`` are missing, or the
            installed ``transformers`` exposes neither auto class.
    """
    try:
        import torch  # type: ignore[import-not-found]
        import transformers  # type: ignore[import-not-found]
        from transformers import AutoProcessor  # type: ignore[import-not-found]
    except ImportError as exc:
        raise ImportError("transformers + torch are required for backend='transformers'.") from exc
    auto_cls = getattr(transformers, "AutoModelForImageTextToText", None) or getattr(
        transformers, "AutoModelForVision2Seq", None
    )
    if auto_cls is None:
        raise ImportError(
            "Neither AutoModelForImageTextToText nor AutoModelForVision2Seq is available in this "
            "transformers version. Install transformers>=4.45 (which has AutoModelForImageTextToText) "
            "for VL models."
        )
    processor = AutoProcessor.from_pretrained(config.model_id, trust_remote_code=config.trust_remote_code)
    use_accelerate = os.environ.get("LEROBOT_TRANSFORMERS_DEVICE_MAP", "manual") != "manual"
    # ``device_map='auto'`` triggers a known std::bad_alloc on the Qwen3-VL
    # post-load dispatch path (the alloc fails in accelerate's hook setup
    # even with TBs of host RAM). Default to manual: load on CPU with
    # ``low_cpu_mem_usage=True``, then ``.to("cuda")``. Set
    # ``LEROBOT_TRANSFORMERS_DEVICE_MAP=auto`` to opt back into the old path.
    if use_accelerate:
        model = auto_cls.from_pretrained(
            config.model_id,
            torch_dtype="auto",
            device_map="auto",
            low_cpu_mem_usage=True,
            trust_remote_code=config.trust_remote_code,
        )
    else:
        # ``torch`` is already imported above; no second deferred import needed.
        model = auto_cls.from_pretrained(
            config.model_id,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=config.trust_remote_code,
        )
        model = model.to("cuda")
    model.eval()

    def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
        """Generate one completion per chat, sequentially (no cross-chat batching)."""
        sample = temp > 0.0
        gen_kwargs: dict[str, Any] = {"max_new_tokens": max_tok, "do_sample": sample}
        if sample:
            # ``temperature`` only applies when sampling; passing it together
            # with ``do_sample=False`` just produces an invalid-flag warning.
            gen_kwargs["temperature"] = temp
        outs: list[str] = []
        for messages in batch:
            text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            # NOTE(review): only the rendered text is handed to the processor;
            # image/video blocks in ``messages`` do not appear to be forwarded
            # as pixel inputs here. Confirm whether that is intended.
            inputs = processor(text=[text], return_tensors="pt").to(model.device)
            with torch.no_grad():
                gen = model.generate(**inputs, **gen_kwargs)
            # Drop the prompt tokens so only the newly generated text is decoded.
            prompt_len = inputs["input_ids"].shape[-1]
            decoded = processor.batch_decode(gen[:, prompt_len:], skip_special_tokens=True)[0]
            outs.append(decoded)
        return outs

    return _GenericTextClient(_gen, config)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_openai_client(config: VlmConfig) -> VlmClient:
    """Build a client that talks to one or more OpenAI-compatible servers.

    By default the server at ``config.api_base`` is expected to be running
    already (e.g. ``vllm serve`` or ``transformers serve``). With
    ``config.auto_serve`` set, this function either spawns
    ``config.parallel_servers`` replicas, reuses a server that already
    answers at ``api_base``, or spawns a single server.

    Requests are spread round-robin over one ``OpenAI`` client per base URL.
    Internal messages are converted by ``_to_openai_messages``: PIL image
    blocks and the frames of ``video`` blocks become ``image_url`` data-URL
    items, while ``video_url`` blocks are forwarded as ``video_url`` items.

    Raises:
        ImportError: If the ``openai`` package is not installed.
    """
    try:
        from openai import OpenAI  # type: ignore[import-not-found]
    except ImportError as exc:
        raise ImportError(
            "openai package is required for backend='openai'. Install with `pip install openai`."
        ) from exc

    base_url = config.api_base
    key = config.api_key
    spawn_allowed = config.auto_serve
    endpoints: list[str] = [base_url]

    print(
        f"[lerobot-annotate] backend=openai model={config.model_id} "
        f"api_base={base_url} auto_serve={spawn_allowed}",
        flush=True,
    )
    if spawn_allowed:
        if config.parallel_servers > 1:
            print(
                f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
                flush=True,
            )
            endpoints = _spawn_parallel_inference_servers(config)
        elif _server_is_up(base_url):
            print(f"[lerobot-annotate] reusing server already up at {base_url}", flush=True)
        else:
            print("[lerobot-annotate] no server reachable; spawning one", flush=True)
            base_url = _spawn_inference_server(config)
            endpoints = [base_url]
            print(f"[lerobot-annotate] server ready at {base_url}", flush=True)

    clients = [OpenAI(base_url=endpoint, api_key=key) for endpoint in endpoints]

    # ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
    # rejects it with HTTP 422. Send it only when explicitly opted in via
    # an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
    opt_in = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "")
    send_mm_kwargs = opt_in.lower() in {"1", "true", "yes"}

    # Round-robin cursor over ``clients``; guarded by ``rr_lock`` because
    # ``_gen`` may call ``_one_call`` from several worker threads.
    next_slot = 0
    rr_lock = threading.Lock()

    def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
        """Send one conversation and return the reply text ('' if empty)."""
        nonlocal next_slot
        api_messages, mm_kwargs = _to_openai_messages(messages)
        request: dict[str, Any] = {
            "model": config.model_id,
            "messages": api_messages,
            "max_tokens": max_tok,
            "temperature": temp,
        }
        extras: dict[str, Any] = {}
        if send_mm_kwargs and mm_kwargs:
            extras["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
        if config.chat_template_kwargs:
            extras["chat_template_kwargs"] = config.chat_template_kwargs
        if extras:
            request["extra_body"] = extras
        with rr_lock:
            chosen = clients[next_slot % len(clients)]
            next_slot += 1
        response = chosen.chat.completions.create(**request)
        return response.choices[0].message.content or ""

    def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
        """Answer every conversation in ``batch``, preserving input order."""
        sequential = len(batch) <= 1 or config.client_concurrency <= 1
        if sequential:
            return [_one_call(messages, max_tok, temp) for messages in batch]
        # Parallel fan-out — vllm batches these on the server side.
        worker_count = min(config.client_concurrency, len(batch))
        with ThreadPoolExecutor(max_workers=worker_count) as pool:
            pending = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
            return [job.result() for job in pending]

    return _GenericTextClient(_gen, config)
|
||||||
|
|
||||||
|
|
||||||
|
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
    """Spawn ``config.parallel_servers`` independent server replicas.

    Each replica:
      - runs ``config.serve_command`` when set, otherwise a default
        ``vllm serve`` command line; a ``{port}`` placeholder in the command
        is substituted, else ``--port <port>`` is appended
      - is restricted to GPU index ``i % num_gpus`` via ``CUDA_VISIBLE_DEVICES``
      - listens on ``serve_port + i``
      - has its combined stdout/stderr echoed with a ``[server-i]`` prefix
      - is stopped (SIGINT, then kill after 15 s) by an atexit hook

    The hook is registered before the first replica is launched, and it is
    also run immediately when startup fails, so a failed launch does not
    leave already-started replicas running.

    Returns:
        The list of ``api_base`` URLs the client should round-robin across.

    Raises:
        RuntimeError: If a replica exits during startup, or if not every
            replica is ready within ``config.serve_ready_timeout_s`` seconds.
    """
    n = config.parallel_servers
    api_bases: list[str] = []
    procs: list[subprocess.Popen] = []
    ready_events: list[threading.Event] = []
    # Multiple readiness signals — uvicorn's own banner is suppressed at
    # ``--uvicorn-log-level warning``, so we also accept vllm's own
    # "Starting vLLM API server" line and the route-listing line. The
    # HTTP probe below is the ultimate fallback.
    ready_markers = (
        "Uvicorn running",
        "Application startup complete",
        "Starting vLLM API server",
        "Available routes are",
    )
    # Single lock for all server-stream threads so multibyte chars from
    # different servers don't interleave and tear UTF-8 sequences.
    print_lock = threading.Lock()

    def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
        # Read whole lines and emit each line atomically under the
        # shared print_lock so output from N servers stays readable.
        assert p.stdout is not None
        for line in iter(p.stdout.readline, ""):
            with print_lock:
                sys.stdout.write(f"[server-{idx}] {line}")
                if not line.endswith(("\n", "\r")):
                    sys.stdout.write("\n")
                sys.stdout.flush()
            if any(m in line for m in ready_markers):
                ev.set()

    def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
        # Poll the HTTP endpoint until the replica is ready or has exited.
        while not ev.is_set() and p.poll() is None:
            if _server_is_up(base):
                print(f"[server-{idx}] ready (http probe)", flush=True)
                ev.set()
                return
            time.sleep(2)

    def _shutdown() -> None:
        # Signal every live replica first so they wind down concurrently,
        # then wait for each one. Safe to call more than once: replicas that
        # already exited are skipped and ``wait`` returns immediately.
        for idx, p in enumerate(procs):
            if p.poll() is None:
                print(f"[server-{idx}] stopping pid={p.pid}", flush=True)
                p.send_signal(signal.SIGINT)
        for p in procs:
            try:
                p.wait(timeout=15)
            except subprocess.TimeoutExpired:
                p.kill()
                p.wait(timeout=5)

    # Registered before any replica exists so that a failure while launching
    # replica k still cleans up replicas 0..k-1 at interpreter exit.
    atexit.register(_shutdown)

    base_cmd = config.serve_command or (
        f"vllm serve {shlex.quote(config.model_id)} "
        f"--tensor-parallel-size 1 "
        f"--max-model-len {config.max_model_len or 32768} "
        f"--uvicorn-log-level warning"
    )
    num_gpus = config.num_gpus if config.num_gpus > 0 else n

    try:
        for i in range(n):
            port = config.serve_port + i
            gpu = i % num_gpus
            env = os.environ.copy()
            env["CUDA_VISIBLE_DEVICES"] = str(gpu)
            cmd = (
                base_cmd.replace("{port}", str(port))
                if "{port}" in base_cmd
                else f"{base_cmd} --port {port}"
            )
            api_base = f"http://localhost:{port}/v1"
            api_bases.append(api_base)
            print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
            proc = subprocess.Popen(
                shlex.split(cmd),
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,
                env=env,
            )
            procs.append(proc)
            ready = threading.Event()
            ready_events.append(ready)

            threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
            threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()

        deadline = time.monotonic() + config.serve_ready_timeout_s
        while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
            for i, p in enumerate(procs):
                if p.poll() is not None:
                    raise RuntimeError(
                        f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
                    )
            time.sleep(2)
        if any(not ev.is_set() for ev in ready_events):
            raise RuntimeError(
                f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s"
            )
    except Exception:
        # Startup failed: stop whatever is still running now instead of
        # leaving replicas alive until the interpreter exits.
        _shutdown()
        raise

    print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
    return api_bases
|
||||||
|
|
||||||
|
|
||||||
|
def _server_is_up(api_base: str) -> bool:
|
||||||
|
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||||
|
url = api_base.rstrip("/") + "/models"
|
||||||
|
# ``api_base`` is the user-configured local-server URL we just spawned
|
||||||
|
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
|
||||||
|
# is for arbitrary user-controlled URLs with file:/ schemes which
|
||||||
|
# cannot reach this code path.
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
|
||||||
|
return resp.status == 200
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _spawn_inference_server(config: VlmConfig) -> str:
    """Spawn ``transformers serve`` (or ``serve_command``) and wait for it.

    The server's combined stdout/stderr is echoed to this process's stdout
    in real time on a background thread, each line prefixed with
    ``[server] ``, so users can see model-load progress and errors as they
    happen. Readiness is declared when a known banner appears in that output
    or when an HTTP probe of ``/v1/models`` succeeds, whichever comes first.
    An atexit hook stops the server (SIGINT, then kill after 15 s).

    Returns:
        The full ``api_base`` URL the OpenAI client should use
        (``http://localhost:<serve_port>/v1``).

    Raises:
        RuntimeError: If the server exits during startup, or does not become
            ready within ``config.serve_ready_timeout_s`` seconds.
    """
    cmd = config.serve_command
    if not cmd:
        cmd = (
            f"transformers serve {shlex.quote(config.model_id)} "
            f"--port {config.serve_port} --continuous-batching"
        )
    api_base = f"http://localhost:{config.serve_port}/v1"
    print(f"[server] launching: {cmd}", flush=True)
    proc = subprocess.Popen(
        shlex.split(cmd),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
    )

    # Watch the server output for the uvicorn readiness banner. This is
    # more reliable than polling /v1/models because transformers serve
    # rescans its cache on every model-list request, which can exceed
    # the urllib timeout and trigger an infinite probe loop.
    ready_event = threading.Event()
    # Besides uvicorn's banners, vllm's own startup and route-listing lines
    # are accepted as readiness signals.
    ready_markers = (
        "Uvicorn running",
        "Application startup complete",
        "Starting vLLM API server",
        "Available routes are",
    )

    def _probe() -> None:
        # HTTP fallback for servers that never print one of the markers.
        while not ready_event.is_set() and proc.poll() is None:
            if _server_is_up(api_base):
                print("[server] ready (http probe)", flush=True)
                ready_event.set()
                return
            time.sleep(2)

    threading.Thread(target=_probe, daemon=True).start()

    def _stream_output() -> None:
        # Read raw chunks instead of iterating lines so tqdm progress
        # bars (which overwrite using \r) flush in real time.
        assert proc.stdout is not None
        buf = ""
        prefix_started = False
        while True:
            ch = proc.stdout.read(1)
            if ch == "":
                # EOF. Everything in ``buf`` was already echoed character by
                # character, so it must not be written again; just terminate
                # the dangling partial line so later output starts cleanly.
                if buf:
                    sys.stdout.write("\n")
                    sys.stdout.flush()
                return
            if not prefix_started:
                sys.stdout.write("[server] ")
                prefix_started = True
            sys.stdout.write(ch)
            sys.stdout.flush()
            buf += ch
            if ch in ("\n", "\r"):
                # A full line is available: scan it for a readiness marker.
                if any(marker in buf for marker in ready_markers):
                    ready_event.set()
                buf = ""
                prefix_started = False

    threading.Thread(target=_stream_output, daemon=True).start()

    def _shutdown() -> None:
        if proc.poll() is None:
            print(f"[server] stopping pid={proc.pid}", flush=True)
            proc.send_signal(signal.SIGINT)
            try:
                proc.wait(timeout=15)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait(timeout=5)

    atexit.register(_shutdown)

    deadline = time.monotonic() + config.serve_ready_timeout_s
    while time.monotonic() < deadline:
        if proc.poll() is not None:
            raise RuntimeError(
                f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
                f"See [server] log lines above for the cause."
            )
        if ready_event.wait(timeout=2):
            return api_base
    proc.terminate()
    raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
|
||||||
|
|
||||||
|
|
||||||
|
def _to_openai_messages(
|
||||||
|
messages: Sequence[dict[str, Any]],
|
||||||
|
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||||
|
"""Convert internal messages to OpenAI chat format.
|
||||||
|
|
||||||
|
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
|
||||||
|
(``fps`` from ``video_url`` blocks) are extracted out so the caller
|
||||||
|
can pass them via ``extra_body.mm_processor_kwargs`` rather than
|
||||||
|
inside the content blocks (which transformers serve rejects).
|
||||||
|
|
||||||
|
File-URL video blocks are inlined as base64 data URLs.
|
||||||
|
"""
|
||||||
|
out_messages: list[dict[str, Any]] = []
|
||||||
|
mm_kwargs: dict[str, Any] = {}
|
||||||
|
for message in messages:
|
||||||
|
content = message.get("content")
|
||||||
|
if not isinstance(content, list):
|
||||||
|
out_messages.append({"role": message["role"], "content": content})
|
||||||
|
continue
|
||||||
|
out_blocks: list[dict[str, Any]] = []
|
||||||
|
for block in content:
|
||||||
|
block_type = block.get("type") if isinstance(block, dict) else None
|
||||||
|
if block_type == "text":
|
||||||
|
out_blocks.append({"type": "text", "text": block.get("text", "")})
|
||||||
|
elif block_type == "image":
|
||||||
|
out_blocks.append(
|
||||||
|
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
|
||||||
|
)
|
||||||
|
elif block_type == "video":
|
||||||
|
frames = block.get("video", [])
|
||||||
|
for img in frames:
|
||||||
|
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
|
||||||
|
elif block_type == "video_url":
|
||||||
|
video_url = dict(block["video_url"])
|
||||||
|
url = video_url.get("url", "")
|
||||||
|
if url.startswith("file://"):
|
||||||
|
video_url["url"] = _file_to_data_url(url[len("file://") :])
|
||||||
|
out_blocks.append({"type": "video_url", "video_url": video_url})
|
||||||
|
fps = block.get("fps")
|
||||||
|
if fps is not None:
|
||||||
|
mm_kwargs["fps"] = fps
|
||||||
|
else:
|
||||||
|
out_blocks.append(block)
|
||||||
|
out_messages.append({"role": message["role"], "content": out_blocks})
|
||||||
|
return out_messages, mm_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def _file_to_data_url(path: str) -> str:
|
||||||
|
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||||
|
return f"data:video/mp4;base64,{b64}"
|
||||||
|
|
||||||
|
|
||||||
|
def _pil_to_data_url(image: Any) -> str:
|
||||||
|
"""Encode a PIL.Image as a base64 data URL."""
|
||||||
|
buf = io.BytesIO()
|
||||||
|
image.save(buf, format="PNG")
|
||||||
|
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||||
|
return f"data:image/png;base64,{b64}"
|
||||||
|
|
||||||
|
|
||||||
|
def _messages_to_prompt(messages: Sequence[dict[str, Any]]) -> Any:
|
||||||
|
"""Pass-through hook used by the vllm backend.
|
||||||
|
|
||||||
|
vllm exposes its own multimodal entry points that vary by version; for the
|
||||||
|
base flow we simply forward the raw message list and let the caller's
|
||||||
|
custom backend handle templating. Real deployments override this.
|
||||||
|
"""
|
||||||
|
return list(messages)
|
||||||
356
src/lerobot/annotations/steerable_pipeline/writer.py
Normal file
356
src/lerobot/annotations/steerable_pipeline/writer.py
Normal file
@@ -0,0 +1,356 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Final parquet rewrite.
|
||||||
|
|
||||||
|
For every episode the writer:
|
||||||
|
|
||||||
|
1. reads the staged module outputs,
|
||||||
|
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
|
||||||
|
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
|
||||||
|
3. sorts each slice deterministically,
|
||||||
|
4. broadcasts the persistent slice across every frame in the episode,
|
||||||
|
5. for each frame, materializes the sublist of event rows whose timestamp
|
||||||
|
exactly equals that frame's timestamp,
|
||||||
|
6. drops the legacy ``subtask_index`` column,
|
||||||
|
7. writes the parquet shard back in place.
|
||||||
|
|
||||||
|
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
|
||||||
|
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
|
||||||
|
struct for every speech atom. The tool *schema* (the description
|
||||||
|
of the ``say`` function and its parameters) is a fixed code constant —
|
||||||
|
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
|
||||||
|
it directly rather than reading a redundant per-row column.
|
||||||
|
|
||||||
|
Invariants enforced here (and re-checked by the validator):
|
||||||
|
|
||||||
|
- per-episode persistent slice is byte-identical across every frame;
|
||||||
|
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
|
||||||
|
(timestamps come straight from the source parquet — never recomputed);
|
||||||
|
- every row passes ``column_for_style(style)``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Iterable, Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
|
||||||
|
from lerobot.datasets.language import (
|
||||||
|
EVENT_ONLY_STYLES,
|
||||||
|
LANGUAGE_EVENTS,
|
||||||
|
LANGUAGE_PERSISTENT,
|
||||||
|
PERSISTENT_STYLES,
|
||||||
|
column_for_style,
|
||||||
|
validate_camera_field,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .reader import EpisodeRecord
|
||||||
|
from .staging import EpisodeStaging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Tool schema constants live in lerobot.datasets.language — single
|
||||||
|
# source of truth. Re-exported here so existing imports
|
||||||
|
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
|
||||||
|
# keep working.
|
||||||
|
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
|
||||||
|
|
||||||
|
|
||||||
|
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
|
||||||
|
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
|
||||||
|
|
||||||
|
|
||||||
|
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||||
|
# events are bucketed per-frame, but within a frame we still want determinism
|
||||||
|
return (
|
||||||
|
row.get("style") or "",
|
||||||
|
row.get("role") or "",
|
||||||
|
row.get("camera") or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
    """Coerce a staged row into the persistent column's struct shape.

    Returns a new dict with exactly the keys ``role``, ``content``,
    ``style``, ``timestamp``, ``camera`` and ``tool_calls``.

    Raises:
        ValueError: If the style is not in ``PERSISTENT_STYLES``, if
            ``timestamp`` or ``role`` is missing, or if ``tool_calls`` is
            neither a list nor ``None``. ``validate_camera_field`` is also
            applied to the row's style/camera pair.
    """
    style = row.get("style")
    if style not in PERSISTENT_STYLES:
        raise ValueError(
            f"persistent slice contains row with non-persistent style {style!r}; "
            "row would be misrouted under column_for_style()"
        )
    # Surface a friendly error from the writer rather than letting a raw
    # KeyError bubble out of the dict accesses below — modules are expected
    # to always emit these keys, but a future bug would otherwise be hard
    # to triage.
    for required in ("timestamp", "role"):
        if required not in row:
            raise ValueError(f"persistent row missing {required}: {row!r}")

    camera = row.get("camera")
    validate_camera_field(style, camera)
    content = row.get("content")
    return {
        "role": str(row["role"]),
        "content": str(content) if content is not None else None,
        "style": style,
        "timestamp": float(row["timestamp"]),
        "camera": str(camera) if camera is not None else None,
        "tool_calls": _normalize_tool_calls(row.get("tool_calls")),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
    """Coerce a staged row into the event column's struct shape.

    Returns a new dict with exactly the keys ``role``, ``content``,
    ``style``, ``camera`` and ``tool_calls`` — event rows carry no timestamp.

    Raises:
        ValueError: If the style is neither ``None`` nor one of
            ``EVENT_ONLY_STYLES``, if it would not route to the events
            column, if ``role`` is missing, or if ``tool_calls`` is neither
            a list nor ``None``. ``validate_camera_field`` is also applied
            to the row's style/camera pair.
    """
    style = row.get("style")
    if not (style is None or style in EVENT_ONLY_STYLES):
        raise ValueError(
            f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
        )
    routed_to = column_for_style(style)
    if routed_to != LANGUAGE_EVENTS:
        raise ValueError(f"event row with style {style!r} would not route to language_events")
    if "role" not in row:
        raise ValueError(f"event row missing role: {row!r}")

    camera = row.get("camera")
    validate_camera_field(style, camera)
    content = row.get("content")
    return {
        "role": str(row["role"]),
        "content": str(content) if content is not None else None,
        "style": style,
        "camera": str(camera) if camera is not None else None,
        "tool_calls": _normalize_tool_calls(row.get("tool_calls")),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_tool_calls(value: Any) -> list[Any] | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if not isinstance(value, list):
|
||||||
|
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
|
||||||
|
return list(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_atom_invariants(row: dict[str, Any]) -> None:
|
||||||
|
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
|
||||||
|
has_content = row.get("content") is not None
|
||||||
|
has_tools = row.get("tool_calls") is not None
|
||||||
|
if not (has_content or has_tools):
|
||||||
|
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
|
||||||
|
if row.get("style") is None and not has_tools:
|
||||||
|
raise ValueError(f"style=None requires tool_calls: {row!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_speech_atom(row: dict[str, Any]) -> None:
|
||||||
|
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
|
||||||
|
if row.get("style") is not None:
|
||||||
|
return # not a speech atom
|
||||||
|
if row.get("role") != "assistant":
|
||||||
|
raise ValueError(f"speech atom must have role=assistant: {row!r}")
|
||||||
|
if row.get("content") is not None:
|
||||||
|
raise ValueError(f"speech atom must have content=null: {row!r}")
|
||||||
|
tool_calls = row.get("tool_calls")
|
||||||
|
if not tool_calls or not isinstance(tool_calls, list):
|
||||||
|
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
|
||||||
|
first = tool_calls[0]
|
||||||
|
if not isinstance(first, dict):
|
||||||
|
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
|
||||||
|
if first.get("type") != "function":
|
||||||
|
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
|
||||||
|
fn = first.get("function") or {}
|
||||||
|
if fn.get("name") != "say":
|
||||||
|
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
|
||||||
|
args = fn.get("arguments") or {}
|
||||||
|
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
|
||||||
|
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LanguageColumnsWriter:
|
||||||
|
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
|
||||||
|
|
||||||
|
drop_existing_subtask_index: bool = True
|
||||||
|
|
||||||
|
def write_all(
|
||||||
|
self,
|
||||||
|
records: Sequence[EpisodeRecord],
|
||||||
|
staging_dir: Path,
|
||||||
|
root: Path,
|
||||||
|
) -> list[Path]:
|
||||||
|
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
|
||||||
|
for record in records:
|
||||||
|
episodes_by_path[record.data_path].append(record)
|
||||||
|
|
||||||
|
written: list[Path] = []
|
||||||
|
for path, eps in episodes_by_path.items():
|
||||||
|
self._rewrite_one(path, eps, staging_dir, root)
|
||||||
|
written.append(path)
|
||||||
|
return written
|
||||||
|
|
||||||
|
def _rewrite_one(
|
||||||
|
self,
|
||||||
|
path: Path,
|
||||||
|
episodes: Sequence[EpisodeRecord],
|
||||||
|
staging_dir: Path,
|
||||||
|
root: Path,
|
||||||
|
) -> None:
|
||||||
|
table = pq.read_table(path)
|
||||||
|
n_rows = table.num_rows
|
||||||
|
|
||||||
|
# Ensure we cover every episode in the file. Episodes that don't have
|
||||||
|
# staging artifacts are passed through with empty annotation lists —
|
||||||
|
# this keeps the writer idempotent and safe for partial reruns.
|
||||||
|
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
|
||||||
|
for record in episodes:
|
||||||
|
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||||
|
staged_per_ep[record.episode_index] = staging.read_all()
|
||||||
|
|
||||||
|
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
|
||||||
|
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
|
||||||
|
|
||||||
|
for ep_index, ep_staged in staged_per_ep.items():
|
||||||
|
persistent_rows: list[dict[str, Any]] = []
|
||||||
|
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
|
||||||
|
for _module_name, rows in ep_staged.items():
|
||||||
|
for row in rows:
|
||||||
|
style = row.get("style")
|
||||||
|
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||||
|
persistent_rows.append(row)
|
||||||
|
else:
|
||||||
|
event_rows.append(row)
|
||||||
|
|
||||||
|
persistent_rows.sort(key=_row_persistent_sort_key)
|
||||||
|
normalized_persistent = []
|
||||||
|
for r in persistent_rows:
|
||||||
|
_validate_atom_invariants(r)
|
||||||
|
_validate_speech_atom(r)
|
||||||
|
normalized_persistent.append(_normalize_persistent_row(r))
|
||||||
|
persistent_by_ep[ep_index] = normalized_persistent
|
||||||
|
|
||||||
|
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
|
||||||
|
for r in event_rows:
|
||||||
|
_validate_atom_invariants(r)
|
||||||
|
_validate_speech_atom(r)
|
||||||
|
ts = float(r["timestamp"])
|
||||||
|
buckets[ts].append(_normalize_event_row(r))
|
||||||
|
for ts in list(buckets.keys()):
|
||||||
|
buckets[ts].sort(key=_row_event_sort_key)
|
||||||
|
events_by_ep_ts[ep_index] = buckets
|
||||||
|
|
||||||
|
episode_col = (
|
||||||
|
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
|
||||||
|
)
|
||||||
|
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
|
||||||
|
if episode_col is None or ts_col is None:
|
||||||
|
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
|
||||||
|
|
||||||
|
per_row_persistent: list[list[dict[str, Any]]] = []
|
||||||
|
per_row_events: list[list[dict[str, Any]]] = []
|
||||||
|
for i in range(n_rows):
|
||||||
|
ep = episode_col[i]
|
||||||
|
ts = float(ts_col[i])
|
||||||
|
per_row_persistent.append(persistent_by_ep.get(ep, []))
|
||||||
|
buckets = events_by_ep_ts.get(ep, {})
|
||||||
|
per_row_events.append(buckets.get(ts, []))
|
||||||
|
|
||||||
|
new_table = self._materialize_table(
|
||||||
|
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||||
|
)
|
||||||
|
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||||
|
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||||
|
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||||
|
# Windows when source and target sit on the same filesystem.
|
||||||
|
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||||
|
pq.write_table(new_table, tmp_path)
|
||||||
|
tmp_path.replace(path)
|
||||||
|
|
||||||
|
def _materialize_table(
    self,
    table: pa.Table,
    persistent: list[list[dict[str, Any]]],
    events: list[list[dict[str, Any]]],
    *,
    drop_old: bool,
) -> pa.Table:
    """Return a copy of ``table`` carrying fresh language columns.

    Every existing column is kept except the two language columns (they are
    re-added from ``persistent`` / ``events``), any legacy ``tools`` column
    emitted by older writers, and, when ``drop_old`` is true,
    ``subtask_index``. The language columns are appended last, persistent
    first.
    """
    # Columns that must not be carried over from the input table.
    discarded = {LANGUAGE_PERSISTENT, LANGUAGE_EVENTS, "tools"}
    if drop_old:
        discarded.add("subtask_index")

    kept_names = [column for column in table.column_names if column not in discarded]
    kept_cols = [table.column(column) for column in kept_names]

    # The struct/list schema is inferred by pyarrow instead of passing the
    # canonical type from `lerobot.datasets.language`: that type uses
    # `pa.json_()` for the `tool_calls` element type, which
    # `pa.array(..., type=...)` cannot materialize from Python lists on
    # current pyarrow versions. The inferred schema round-trips through
    # parquet and `LeRobotDataset` correctly; `tests/datasets/test_language.py`
    # exercises the same flow.
    language_cols = [pa.array(persistent), pa.array(events)]

    return pa.Table.from_arrays(
        [*kept_cols, *language_cols],
        names=[*kept_names, LANGUAGE_PERSISTENT, LANGUAGE_EVENTS],
    )
|
||||||
|
|
||||||
|
|
||||||
|
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
    """Build a canonical speech tool-call atom for the events column.

    The atom is an assistant row with no content, style or camera, stamped
    with ``timestamp`` (coerced to ``float``) and carrying a single ``say``
    function call whose only argument is ``text``.
    """
    atom: dict[str, Any] = {"role": "assistant"}
    atom["content"] = None
    atom["style"] = None
    atom["timestamp"] = float(timestamp)
    atom["camera"] = None

    say_call = {
        "type": "function",
        "function": {"name": "say", "arguments": {"text": text}},
    }
    atom["tool_calls"] = [say_call]
    return atom
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_rows_for_writer(
    rows: Iterable[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Partition a flat row list into ``(persistent_rows, event_rows)``.

    Helper used by tests/validators. A row goes to the persistent list when
    ``column_for_style`` maps its ``style`` to the persistent column; every
    other row goes to the events list. Input order is preserved in each list.
    """
    persistent_rows: list[dict[str, Any]] = []
    event_rows: list[dict[str, Any]] = []
    for entry in rows:
        is_persistent = column_for_style(entry.get("style")) == LANGUAGE_PERSISTENT
        destination = persistent_rows if is_persistent else event_rows
        destination.append(entry)
    return persistent_rows, event_rows
|
||||||
@@ -46,7 +46,7 @@ CORE_STYLES = {
|
|||||||
EXTENDED_STYLES: set[str] = set()
|
EXTENDED_STYLES: set[str] = set()
|
||||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||||
|
|
||||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug", "action_record"}
|
||||||
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||||
|
|
||||||
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
||||||
|
|||||||
200
src/lerobot/scripts/lerobot_annotate.py
Normal file
200
src/lerobot/scripts/lerobot_annotate.py
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""``lerobot-annotate`` — populate ``language_persistent`` and
|
||||||
|
``language_events`` columns on a LeRobot dataset.
|
||||||
|
|
||||||
|
Annotations live directly in ``data/chunk-*/file-*.parquet``.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
uv run lerobot-annotate \\
|
||||||
|
--root=/path/to/dataset \\
|
||||||
|
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
|
||||||
|
For distributed runs, see ``examples/annotations/run_hf_job.py``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||||
|
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||||
|
from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
|
||||||
|
from lerobot.annotations.steerable_pipeline.modules import (
|
||||||
|
GeneralVqaModule,
|
||||||
|
InterjectionsAndSpeechModule,
|
||||||
|
PlanSubtasksMemoryModule,
|
||||||
|
)
|
||||||
|
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||||
|
from lerobot.annotations.steerable_pipeline.vlm_client import make_vlm_client
|
||||||
|
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||||
|
from lerobot.configs import parser
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_root(cfg: AnnotationPipelineConfig) -> Path:
    """Return the local dataset directory to annotate.

    ``cfg.root`` wins when set; otherwise the dataset repo ``cfg.repo_id`` is
    fetched with ``snapshot_download`` and the snapshot path is returned.

    Raises:
        ValueError: if neither ``cfg.root`` nor ``cfg.repo_id`` is set.
    """
    if cfg.root is None and cfg.repo_id is None:
        raise ValueError("Either --root or --repo_id must be provided.")

    if cfg.root is not None:
        return Path(cfg.root)

    # Imported lazily so a purely local run never needs the hub client.
    from huggingface_hub import snapshot_download

    snapshot_dir = snapshot_download(repo_id=cfg.repo_id, repo_type="dataset")
    return Path(snapshot_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@parser.wrap()
def annotate(cfg: AnnotationPipelineConfig) -> None:
    """Run the steerable annotation pipeline against a dataset.

    Resolves the dataset root, builds the VLM client, the frame provider and
    the plan / interjections / vqa modules, runs them through ``Executor``,
    logs a per-phase summary plus any validation warnings, and finally
    uploads the result when ``cfg.push_to_hub`` is set.

    Raises:
        ValueError: if neither ``--root`` nor ``--repo_id`` is provided, or if
            ``--push_to_hub`` is set without ``--repo_id`` / ``--dest_repo_id``.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    root = _resolve_root(cfg)
    logger.info("annotate: root=%s", root)

    # Fail fast: the push target depends only on the config, so reject a
    # misconfigured run before any annotation work is done rather than after
    # the whole pipeline has already executed.
    if cfg.push_to_hub and cfg.repo_id is None and cfg.dest_repo_id is None:
        raise ValueError("--push_to_hub requires --repo_id or --dest_repo_id (the dataset repo to push to).")

    vlm = make_vlm_client(cfg.vlm)
    frame_provider = make_frame_provider(
        root, camera_key=cfg.vlm.camera_key, video_backend=cfg.video_backend
    )
    # Surface the resolved cameras up front so a silent vqa-module no-op
    # is obvious in job output rather than discovered post-hoc by counting
    # parquet rows.
    cam_keys = list(getattr(frame_provider, "camera_keys", []) or [])
    logger.info(
        "annotate: frame_provider default camera=%r, all cameras=%s",
        getattr(frame_provider, "camera_key", None),
        cam_keys,
    )
    if cfg.vqa.enabled and not cam_keys:
        logger.warning(
            "annotate: the vqa module is enabled but no cameras were "
            "resolved — it will produce zero VQA rows. Check "
            "meta/info.json for observation.images.* features, or pass "
            "--vlm.camera_key=<key> to seed the cameras list."
        )

    plan = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan, frame_provider=frame_provider)
    interjections = InterjectionsAndSpeechModule(
        vlm=vlm, config=cfg.interjections, seed=cfg.seed, frame_provider=frame_provider
    )
    vqa = GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed, frame_provider=frame_provider)
    writer = LanguageColumnsWriter()
    # Reuse the camera list resolved above; an empty list is passed as None.
    validator = StagingValidator(dataset_camera_keys=tuple(cam_keys) or None)

    executor = Executor(
        config=cfg,
        plan=plan,
        interjections=interjections,
        vqa=vqa,
        writer=writer,
        validator=validator,
    )
    summary = executor.run(root)

    logger.info("annotate: wrote %d shard(s)", len(summary.written_paths))
    for phase in summary.phases:
        logger.info(
            "annotate: phase=%s processed=%d skipped=%d",
            phase.name,
            phase.episodes_processed,
            phase.episodes_skipped,
        )
    for warning in summary.validation_report.warnings or ():
        # Pass the text as an argument so it is never treated as a format string.
        logger.warning("%s", warning)

    if cfg.push_to_hub:
        _push_to_hub(root, cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
    """Upload the annotated dataset directory to the Hub and tag it.

    Pushes to ``cfg.dest_repo_id`` when set, otherwise back to ``cfg.repo_id``.
    The repo is created if needed, ``root`` is uploaded in one commit (the
    annotation staging directory and ``.DS_Store`` files are excluded), and
    the upload is tagged with the dataset's codebase version. A tagging
    failure is reported on stdout but does not raise.
    """
    from huggingface_hub import HfApi  # noqa: PLC0415

    target_repo = cfg.dest_repo_id or cfg.repo_id
    message = cfg.push_commit_message or "Add steerable annotations (lerobot-annotate)"
    hub = HfApi()

    print(f"[lerobot-annotate] creating/locating dataset repo {target_repo}...", flush=True)
    hub.create_repo(
        repo_id=target_repo,
        repo_type="dataset",
        private=cfg.push_private,
        exist_ok=True,
    )

    print(f"[lerobot-annotate] uploading {root} -> {target_repo}...", flush=True)
    upload_result = hub.upload_folder(
        folder_path=str(root),
        repo_id=target_repo,
        repo_type="dataset",
        commit_message=message,
        ignore_patterns=[".annotate_staging/**", "**/.DS_Store"],
    )
    print(f"[lerobot-annotate] uploaded to https://huggingface.co/datasets/{target_repo}", flush=True)

    # Tag the upload with the codebase version. ``LeRobotDatasetMetadata``
    # resolves the dataset revision via ``get_safe_version`` which scans
    # for tags like ``v3.0``; without a tag it raises
    # ``RevisionNotFoundError``. The version is read straight from the
    # dataset's own ``meta/info.json`` so the tag matches whatever the writer
    # actually wrote (no accidental drift if the codebase floor moves).
    from lerobot.datasets.dataset_metadata import CODEBASE_VERSION  # noqa: PLC0415

    tag_name = CODEBASE_VERSION
    info_file = root / "meta" / "info.json"
    if info_file.exists():
        try:
            from lerobot.utils.io_utils import load_json  # noqa: PLC0415

            recorded = load_json(info_file).get("codebase_version")
            if isinstance(recorded, str) and recorded.startswith("v"):
                tag_name = recorded
        except Exception as exc:  # noqa: BLE001
            print(
                f"[lerobot-annotate] could not read codebase_version from info.json ({exc}); "
                f"falling back to {tag_name}",
                flush=True,
            )

    create_tag_args = {
        "repo_id": target_repo,
        "tag": tag_name,
        "repo_type": "dataset",
        "exist_ok": True,
    }
    # Pin the tag to the commit just created when the upload result exposes it.
    commit_sha = getattr(upload_result, "oid", None)
    if commit_sha is not None:
        create_tag_args["revision"] = commit_sha

    try:
        hub.create_tag(**create_tag_args)
        print(f"[lerobot-annotate] tagged {target_repo} as {tag_name}", flush=True)
    except Exception as exc:  # noqa: BLE001
        print(
            f"[lerobot-annotate] WARNING: could not create tag {tag_name!r} on {target_repo}: {exc}. "
            "Dataset is uploaded but ``LeRobotDataset`` won't be able to load it until it's tagged. "
            "Run: from huggingface_hub import HfApi; "
            f"HfApi().create_tag({target_repo!r}, tag={tag_name!r}, repo_type='dataset', exist_ok=True)",
            flush=True,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Console-script entry point: delegate to :func:`annotate`."""
    annotate()


# Allow running the module directly as a script.
if __name__ == "__main__":
    main()
|
||||||
0
tests/annotations/__init__.py
Normal file
0
tests/annotations/__init__.py
Normal file
58
tests/annotations/_helpers.py
Normal file
58
tests/annotations/_helpers.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Helpers shared across annotation-pipeline tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||||
|
|
||||||
|
|
||||||
|
def make_canned_responder(
    responses_by_marker: dict[str, Any],
    default: Any = None,
) -> StubVlmClient:
    """Return a stub that picks a response by inspecting the user prompt.

    On each call the responder takes the text of the last user message (for
    block-list content, the last ``text`` block) and returns the response of
    the first marker, in ``responses_by_marker`` order, that occurs in it.
    ``default`` is returned when no marker matches.
    """

    def responder(messages: list[dict[str, Any]]) -> Any:
        prompt = ""
        for msg in messages:
            if msg.get("role") != "user":
                continue
            body = msg.get("content")
            if isinstance(body, str):
                prompt = body
                continue
            if not isinstance(body, list):
                continue
            texts = [
                part.get("text", "")
                for part in body
                if isinstance(part, dict) and part.get("type") == "text"
            ]
            # A message without any text block leaves the previous prompt in place.
            if texts:
                prompt = texts[-1]

        matches = (reply for marker, reply in responses_by_marker.items() if marker in prompt)
        return next(matches, default)

    return StubVlmClient(responder=responder)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_vqa_answer(payload: dict[str, Any]) -> str:
    """Serialize ``payload`` to JSON text with its keys in sorted order."""
    encoded = json.dumps(payload, sort_keys=True)
    return encoded
|
||||||
51
tests/annotations/conftest.py
Normal file
51
tests/annotations/conftest.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Shared fixtures for annotation-pipeline tests.
|
||||||
|
|
||||||
|
The on-disk dataset builder lives with the other dataset factories in
|
||||||
|
``tests/fixtures/dataset_factories.py`` (:func:`build_annotation_dataset`);
|
||||||
|
these fixtures only wire it into pytest.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tests.fixtures.dataset_factories import build_annotation_dataset
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def fixture_dataset_root(tmp_path: Path) -> Path:
    """A tiny dataset with two episodes, 12 frames each at 10 fps."""
    two_episodes = [
        (0, 12, "Could you tidy the kitchen please?"),
        (1, 12, "Please clean up the kitchen"),
    ]
    dataset_dir = tmp_path / "ds"
    return build_annotation_dataset(dataset_dir, episode_specs=two_episodes, fps=10)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def single_episode_root(tmp_path: Path) -> Path:
    """A dataset with a single 30-frame episode at 10 fps."""
    one_episode = [(0, 30, "Pour water from the bottle into the cup.")]
    dataset_dir = tmp_path / "ds_one"
    return build_annotation_dataset(dataset_dir, episode_specs=one_episode, fps=10)
|
||||||
101
tests/annotations/run_e2e_smoke.py
Normal file
101
tests/annotations/run_e2e_smoke.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Opt-in E2E smoke run for ``make annotation-e2e``.
|
||||||
|
|
||||||
|
Builds the shared annotation fixture (:func:`build_annotation_dataset`),
|
||||||
|
runs the full annotation pipeline against it with a stub VLM, and prints a
|
||||||
|
short report. This is intentionally not a pytest test — it exercises the
|
||||||
|
CLI plumbing — but it reuses the same on-disk dataset builder as the pytest
|
||||||
|
fixtures so there is no duplicated fixture code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||||
|
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||||
|
from lerobot.annotations.steerable_pipeline.modules import (
|
||||||
|
GeneralVqaModule,
|
||||||
|
InterjectionsAndSpeechModule,
|
||||||
|
PlanSubtasksMemoryModule,
|
||||||
|
)
|
||||||
|
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||||
|
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||||
|
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||||
|
from tests.fixtures.dataset_factories import build_annotation_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def _stub_responder(messages):
|
||||||
|
text = ""
|
||||||
|
for m in messages:
|
||||||
|
if m.get("role") == "user":
|
||||||
|
content = m.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "text":
|
||||||
|
text = block.get("text", "")
|
||||||
|
elif isinstance(content, str):
|
||||||
|
text = content
|
||||||
|
if "atomic subtasks" in text:
|
||||||
|
return {
|
||||||
|
"subtasks": [
|
||||||
|
{"text": "grasp the bottle", "start": 0.0, "end": 1.0},
|
||||||
|
{"text": "pour into the cup", "start": 1.0, "end": 2.0},
|
||||||
|
{"text": "place the bottle down", "start": 2.0, "end": 3.0},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
if "concise hierarchical PLAN" in text:
|
||||||
|
return {"plan": "1. grasp\n2. pour\n3. place"}
|
||||||
|
if "Update the memory" in text:
|
||||||
|
return {"memory": "poured once"}
|
||||||
|
if "acknowledgement the robot" in text:
|
||||||
|
return {"text": "Sure."}
|
||||||
|
if "ONE realistic interruption" in text:
|
||||||
|
return {"interjection": "use less water", "speech": "Using less water."}
|
||||||
|
if "frame-grounded visual question" in text:
|
||||||
|
return {"question": "How many cups?", "answer": {"label": "cup", "count": 1}}
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
    """Run the pipeline on a throwaway fixture dataset and print a report.

    Builds a one-episode dataset in a temporary directory, wires every module
    to the stub VLM, executes the pipeline and prints the processed phases,
    the validation summary and the number of rewritten shards. Returns ``0``.
    """
    with tempfile.TemporaryDirectory() as scratch:
        dataset_root = build_annotation_dataset(
            Path(scratch) / "ds",
            episode_specs=[(0, 30, "Pour water into the cup.")],
            fps=10,
        )
        stub_vlm = StubVlmClient(responder=_stub_responder)
        config = AnnotationPipelineConfig()

        plan_module = PlanSubtasksMemoryModule(vlm=stub_vlm, config=config.plan)
        interjection_module = InterjectionsAndSpeechModule(
            vlm=stub_vlm, config=config.interjections, seed=config.seed
        )
        vqa_module = GeneralVqaModule(vlm=stub_vlm, config=config.vqa, seed=config.seed)

        pipeline = Executor(
            config=config,
            plan=plan_module,
            interjections=interjection_module,
            vqa=vqa_module,
            writer=LanguageColumnsWriter(),
            validator=StagingValidator(),
        )
        report = pipeline.run(dataset_root)

        processed = [(p.name, p.episodes_processed) for p in report.phases]
        print(f"phases={processed}")
        print(f"validation: {report.validation_report.summary()}")
        print(f"shards rewritten: {len(report.written_paths)}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||||
179
tests/annotations/test_frames.py
Normal file
179
tests/annotations/test_frames.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Unit tests for :class:`VideoFrameProvider` method bindings.
|
||||||
|
|
||||||
|
These were prompted by a real regression: ``video_for_episode`` was once
|
||||||
|
indented one level too deep so it ended up nested *inside* a module-level
|
||||||
|
helper (after that function's ``return`` statement) — silently dead code
|
||||||
|
that meant production runs with ``use_video_url=False`` would
|
||||||
|
``AttributeError`` on ``self.frame_provider.video_for_episode(...)``. The
|
||||||
|
existing module tests didn't catch it because they exercise stub providers.
|
||||||
|
|
||||||
|
The tests below assert on the class itself (not on an instance), so a
|
||||||
|
future reindent regression flips them to red without needing a real
|
||||||
|
LeRobot dataset on disk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||||
|
|
||||||
|
from lerobot.annotations.steerable_pipeline.frames import ( # noqa: E402
|
||||||
|
VideoFrameProvider,
|
||||||
|
_decode_frames_av,
|
||||||
|
_decode_frames_ffmpeg,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeMeta:
|
||||||
|
"""Minimal metadata stub exposing ``video_keys`` / ``camera_keys``."""
|
||||||
|
|
||||||
|
def __init__(self, video_keys: list[str], image_keys: list[str]) -> None:
|
||||||
|
self.video_keys = video_keys
|
||||||
|
self.camera_keys = [*video_keys, *image_keys]
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_camera_key_skips_image_only_cameras(tmp_path: Path, monkeypatch) -> None:
    """The default camera must be a *video* key — image-stored cameras have no
    ``videos/<key>/from_timestamp`` and would KeyError in the clip/decode path.

    Regression: a dataset whose first ``camera_keys`` entry was an image-stored
    camera (e.g. ``observation.images.wrist``) crashed at clip extraction.
    """
    import lerobot.datasets.dataset_metadata as meta_mod

    video_cam = "observation.images.robot0_agentview_right"
    image_cam = "observation.images.wrist"
    stub_meta = _FakeMeta(video_keys=[video_cam], image_keys=[image_cam])

    def _fake_metadata(*args, **kwargs):
        # Stand-in for the real metadata constructor: ignores its arguments.
        return stub_meta

    monkeypatch.setattr(meta_mod, "LeRobotDatasetMetadata", _fake_metadata, raising=True)
    provider = VideoFrameProvider(root=tmp_path)

    assert provider.camera_key == video_cam
    assert image_cam not in provider.camera_keys
|
||||||
|
|
||||||
|
|
||||||
|
def test_video_for_episode_is_a_method_of_videoframeprovider():
    """``video_for_episode`` must be defined on the class, not nested dead code."""
    attribute = getattr(VideoFrameProvider, "video_for_episode", None)
    assert callable(attribute)
|
||||||
|
|
||||||
|
|
||||||
|
def test_episode_clip_path_is_a_method_of_videoframeprovider():
    """``episode_clip_path`` is now a method (it used to be a free function
    reaching into ``provider._meta`` from outside the class)."""
    attribute = getattr(VideoFrameProvider, "episode_clip_path", None)
    assert callable(attribute)
|
||||||
|
|
||||||
|
|
||||||
|
def test_videoframeprovider_has_a_lock_for_concurrent_use():
    """A ``ThreadPoolExecutor`` runs the plan / interjections / vqa phases
    concurrently; the cache + warn-flag accesses must be guarded.
    """
    import threading

    # Inspect the dataclass field declaration instead of building an
    # instance, so the check never touches the hub: the ``_lock`` field must
    # exist and use ``threading.Lock`` as its default factory.
    declared = VideoFrameProvider.__dataclass_fields__.get("_lock")

    assert declared is not None
    assert declared.default_factory is threading.Lock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_video(tmp_path: Path) -> Path:
    """A 3 s 10 fps test-pattern mp4, written with ffmpeg.

    The requesting test is skipped when no ``ffmpeg`` binary is on ``PATH``.
    """
    if shutil.which("ffmpeg") is None:
        pytest.skip("ffmpeg not available")

    target = tmp_path / "sample.mp4"
    command = ["ffmpeg", "-y"]
    command += ["-f", "lavfi", "-i", "testsrc=duration=3:size=160x120:rate=10"]
    command += ["-pix_fmt", "yuv420p", str(target)]
    subprocess.run(command, check=True, capture_output=True)
    return target
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_frames_av_returns_one_uint8_frame_per_timestamp(sample_video: Path) -> None:
    """``_decode_frames_av`` decodes via PyAV directly — no torchcodec/torchvision.

    This is the always-available fallback: torchcodec is unusable in some
    containers and lerobot's ``pyav`` backend routes through the removed
    ``torchvision.io.VideoReader``.
    """
    requested = [0.0, 1.0, 2.5]
    decoded = _decode_frames_av(sample_video, requested)

    assert len(decoded) == len(requested)
    expected_shape = (3, 120, 160)  # channels-first, matching the 160x120 fixture
    for image in decoded:
        assert isinstance(image, torch.Tensor)
        assert image.dtype == torch.uint8
        assert image.shape == expected_shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_frames_av_picks_nearest_frame(sample_video: Path) -> None:
    """Repeated and out-of-order timestamps each resolve to the nearest frame."""
    late_a, early, late_b = _decode_frames_av(sample_video, [2.0, 0.0, 2.0])

    # The same timestamp requested twice yields identical pixels ...
    assert torch.equal(late_a, late_b)
    # ... while a different timestamp yields a different frame.
    assert not torch.equal(late_a, early)
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_frames_av_raises_on_missing_file(tmp_path: Path) -> None:
    """A missing video surfaces as an exception the caller can fall back on."""
    absent = tmp_path / "does_not_exist.mp4"
    with pytest.raises(Exception):  # noqa: B017, PT011
        _decode_frames_av(absent, [0.0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_frames_ffmpeg_returns_one_uint8_frame_per_timestamp(sample_video: Path) -> None:
    """``_decode_frames_ffmpeg`` shells out to the ffmpeg CLI — the always-
    available fallback that decodes AV1 and isolates crashes to a child
    process.
    """
    requested = [0.0, 1.0, 2.5]
    decoded = _decode_frames_ffmpeg(sample_video, requested)

    assert len(decoded) == len(requested)
    expected_shape = (3, 120, 160)  # channels-first, matching the 160x120 fixture
    for image in decoded:
        assert isinstance(image, torch.Tensor)
        assert image.dtype == torch.uint8
        assert image.shape == expected_shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_frames_ffmpeg_raises_on_missing_file(tmp_path: Path) -> None:
    """A missing video raises (non-zero ffmpeg exit), never crashes the job."""
    if shutil.which("ffmpeg") is None:
        pytest.skip("ffmpeg not available")

    absent = tmp_path / "does_not_exist.mp4"
    with pytest.raises(Exception):  # noqa: B017, PT011
        _decode_frames_ffmpeg(absent, [0.0])
|
||||||
355
tests/annotations/test_modules.py
Normal file
355
tests/annotations/test_modules.py
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Module 1/2/3 unit tests with stubbed VLMs."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from lerobot.annotations.steerable_pipeline.config import (
|
||||||
|
InterjectionsConfig,
|
||||||
|
PlanConfig,
|
||||||
|
VqaConfig,
|
||||||
|
)
|
||||||
|
from lerobot.annotations.steerable_pipeline.modules import (
|
||||||
|
GeneralVqaModule,
|
||||||
|
InterjectionsAndSpeechModule,
|
||||||
|
PlanSubtasksMemoryModule,
|
||||||
|
)
|
||||||
|
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||||
|
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||||
|
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||||
|
|
||||||
|
from ._helpers import make_canned_responder
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _StubFrameProvider:
    """Frame-provider test double that hands back a sentinel instead of real frames.

    Every call is recorded (``calls`` for ``frames_at``, ``video_calls`` for
    ``video_for_episode``) so tests can inspect what was requested.
    """

    # Unique placeholder returned in place of each decoded frame.
    sentinel: Any = field(default_factory=object)
    # Camera keys exposed through ``camera_keys``.
    cameras: tuple[str, ...] = ("observation.images.top",)
    # One (episode_index, timestamps, camera_key) entry per ``frames_at`` call.
    calls: list[tuple[int, tuple[float, ...], str | None]] = field(default_factory=list)
    # One (episode_index, max_frames, camera_key) entry per ``video_for_episode`` call.
    video_calls: list[tuple[int, int, str | None]] = field(default_factory=list)

    @property
    def camera_keys(self) -> list[str]:
        """Return the configured cameras as a fresh list."""
        return [*self.cameras]

    def frames_at(self, record, timestamps, camera_key=None):
        """Log the request and return one sentinel per requested timestamp."""
        request = (record.episode_index, tuple(timestamps), camera_key)
        self.calls.append(request)
        return len(timestamps) * [self.sentinel]

    def video_for_episode(self, record, max_frames, camera_key=None):
        """Log the request and return at most ``max_frames`` sentinels.

        The count is capped by the number of frame timestamps on ``record``.
        """
        request = (record.episode_index, max_frames, camera_key)
        self.video_calls.append(request)
        frame_count = min(max_frames, len(record.frame_timestamps))
        return frame_count * [self.sentinel]
|
||||||
|
|
||||||
|
|
||||||
|
def _spy_responder(captured: list[list[dict[str, Any]]], reply: Any):
    """Build a ``StubVlmClient`` that records every prompt and always answers ``reply``.

    Each incoming ``messages`` sequence is copied into ``captured`` so the
    caller can inspect the prompts after the module under test has run.
    """

    def _record_and_reply(messages):
        captured.append([*messages])
        return reply

    return StubVlmClient(responder=_record_and_reply)
|
||||||
|
|
||||||
|
|
||||||
|
def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """Module 1 stages subtask, plan and memory rows aligned to source frames."""
    canned_subtasks = [
        {"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.4},
        {"text": "wipe the counter from left to right", "start": 0.4, "end": 0.8},
        {"text": "place the sponge into the sink", "start": 0.8, "end": 1.1},
    ]
    responder = make_canned_responder(
        {
            "atomic subtasks": {"subtasks": canned_subtasks},
            "Update the memory": {"memory": "wiped the counter once"},
        },
    )
    planner = PlanSubtasksMemoryModule(vlm=responder, config=PlanConfig())
    episode = next(iter_episodes(fixture_dataset_root))
    stage = EpisodeStaging(tmp_path / "stage", episode.episode_index)
    planner.run_episode(episode, stage)
    staged = stage.read("plan")

    # All three styles must show up in the staged output.
    assert {"subtask", "plan", "memory"} <= {entry["style"] for entry in staged}

    # Every staged row has to sit exactly on a source frame timestamp.
    valid_timestamps = set(episode.frame_timestamps)
    assert all(entry["timestamp"] in valid_timestamps for entry in staged)

    # One plan row per subtask boundary; the first lands at t0 and each
    # plan is the deterministic numbered list of still-todo subtasks.
    subtasks = [entry for entry in staged if entry["style"] == "subtask"]
    plans = [entry for entry in staged if entry["style"] == "plan"]
    plans.sort(key=lambda entry: entry["timestamp"])
    assert len(plans) == len(subtasks)

    earliest_plan = plans[0]
    assert earliest_plan["timestamp"] == episode.frame_timestamps[0]
    # The t0 plan enumerates all subtasks; later plans shrink.
    assert earliest_plan["content"].startswith("1. ")
    assert len(earliest_plan["content"].splitlines()) == len(subtasks)
    latest_plan = plans[-1]
    assert len(latest_plan["content"].splitlines()) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_module2_at_t0_emits_speech_only_no_interjection(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """With zero interjections allowed, Module 2 stages a single ``say`` row at t0."""
    responder = make_canned_responder(
        {"acknowledgement the robot": {"text": "Sure, on it."}},
    )
    no_interjections = InterjectionsConfig(max_interjections_per_episode=0)
    module = InterjectionsAndSpeechModule(vlm=responder, config=no_interjections)
    episode = next(iter_episodes(fixture_dataset_root))
    stage = EpisodeStaging(tmp_path / "stage", episode.episode_index)
    module.run_episode(episode, stage)

    staged = stage.read("interjections")
    assert len(staged) == 1
    speech = staged[0]

    # The speech row carries no style and no text content; the utterance is
    # expressed through the ``say`` tool call instead.
    assert speech["role"] == "assistant"
    assert speech["style"] is None
    assert speech["content"] is None
    assert speech["timestamp"] == episode.frame_timestamps[0]
    first_call = speech["tool_calls"][0]
    assert first_call["function"]["name"] == "say"
|
||||||
|
|
||||||
|
|
||||||
|
def test_module2_mid_episode_emits_paired_interjection_and_speech(
    fixture_dataset_root: Path, tmp_path: Path
) -> None:
    """Module 2 anchors interjections on Module 1's subtask boundaries.

    The executor runs Module 1 first, then Module 2 reads the subtask
    rows back from the same staging tree (see
    ``_mid_episode_interjections``). Reproduce that contract here by
    seeding the staging with two subtask rows so a single ``0 → 1``
    boundary exists for Module 2 to anchor on.
    """

    def _subtask_row(text, timestamp):
        # Same row layout Module 1 stages for a subtask.
        return {
            "role": "assistant",
            "content": text,
            "style": "subtask",
            "timestamp": timestamp,
            "tool_calls": None,
        }

    canned = {
        "acknowledgement the robot": {"text": "OK."},
        # Marker matches the distinctive line of
        # ``module_2_interjection.txt``. The old marker
        # ("ONE realistic interruption") came from a previous prompt
        # version that asked for counterfactual interjections; the
        # current design anchors on subtask boundaries instead, so
        # the prompt and its marker changed.
        "Write ONE interjection": {
            "interjection": "now wipe the counter please",
            "speech": "On it.",
        },
    }
    module = InterjectionsAndSpeechModule(
        vlm=make_canned_responder(canned),
        config=InterjectionsConfig(max_interjections_per_episode=1, interjection_min_t=0.2),
        seed=7,
    )
    episode = next(iter_episodes(fixture_dataset_root))
    stage = EpisodeStaging(tmp_path / "stage", episode.episode_index)

    # Seed Module 1's subtask staging so Module 2 has a boundary to
    # anchor on (it bails with zero rows when no spans exist — the
    # production executor guarantees Module 1 ran first).
    frame_times = episode.frame_timestamps
    boundary_ts = float(frame_times[len(frame_times) // 2])
    seeded_subtasks = [
        _subtask_row("grasp the sponge", float(frame_times[0])),
        _subtask_row("wipe the counter", boundary_ts),
    ]
    stage.write("plan", seeded_subtasks)

    module.run_episode(episode, stage)
    staged = stage.read("interjections")

    interjections = []
    speeches = []
    for entry in staged:
        if entry["style"] == "interjection":
            interjections.append(entry)
        elif entry["style"] is None and entry["role"] == "assistant":
            speeches.append(entry)

    assert len(interjections) == 1
    assert len(speeches) >= 2  # initial t=0 + one paired with the interjection
    interjection_time = interjections[0]["timestamp"]
    paired = [s for s in speeches if abs(s["timestamp"] - interjection_time) < 1e-9]
    assert paired
|
||||||
|
|
||||||
|
|
||||||
|
def test_module3_vqa_unique_per_frame_and_camera(single_episode_root: Path, tmp_path: Path) -> None:
    """VQA rows are camera-tagged, unique per (timestamp, camera) and frame-aligned."""
    camera_names = ("observation.images.top", "observation.images.wrist")
    canned_answer = {
        "question": "How many cups?",
        "answer": {"label": "cup", "count": 2, "note": "white & blue"},
    }
    module = GeneralVqaModule(
        vlm=make_canned_responder({"frame-grounded visual question": canned_answer}),
        config=VqaConfig(vqa_emission_hz=1.0, K=3),
        seed=1,
        frame_provider=_StubFrameProvider(cameras=camera_names),
    )
    episode = next(iter_episodes(single_episode_root))
    stage = EpisodeStaging(tmp_path / "stage", episode.episode_index)
    module.run_episode(episode, stage)
    staged = stage.read("vqa")

    # Every vqa row must carry a camera tag and one of the configured cameras.
    allowed_cameras = {"observation.images.top", "observation.images.wrist"}
    for entry in staged:
        assert entry["style"] == "vqa"
        assert entry.get("camera") in allowed_cameras

    def _keys_for(role):
        # (timestamp, camera) pairs of the vqa rows emitted by ``role``.
        return [
            (entry["timestamp"], entry["camera"])
            for entry in staged
            if entry["role"] == role and entry["style"] == "vqa"
        ]

    # At most one (vqa, user) and one (vqa, assistant) per (timestamp, camera).
    user_keys = _keys_for("user")
    assistant_keys = _keys_for("assistant")
    assert len(set(user_keys)) == len(user_keys)
    assert len(set(assistant_keys)) == len(assistant_keys)

    # Both cameras must be represented.
    cameras_seen = {camera for _, camera in user_keys}
    assert cameras_seen == {"observation.images.top", "observation.images.wrist"}

    # Every emitted timestamp must be an exact source frame timestamp.
    valid_timestamps = set(episode.frame_timestamps)
    assert all(ts in valid_timestamps for ts, _ in user_keys + assistant_keys)
|
||||||
|
|
||||||
|
|
||||||
|
def test_module1_attaches_video_block_to_subtask_prompt(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """Module 1 sends one ``type=video`` block covering the whole episode."""
    seen_prompts: list[list[dict[str, Any]]] = []
    subtask_reply = {
        "subtasks": [
            {"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.5},
            {"text": "wipe the counter", "start": 0.5, "end": 1.1},
        ]
    }
    # Checked in order; anything that matches neither marker gets the subtask reply.
    routed_replies = (
        ("concise hierarchical PLAN", {"plan": "1. grasp\n2. wipe"}),
        ("Update the memory", {"memory": "wiped once"}),
    )

    def _is_block(candidate, kind):
        return isinstance(candidate, dict) and candidate.get("type") == kind

    def responder(messages):
        seen_prompts.append(list(messages))
        # Route on the text of the last text block found in the prompt.
        last_text = ""
        for message in messages:
            for part in message.get("content", []):
                if _is_block(part, "text"):
                    last_text = part.get("text", "")
        for marker, canned in routed_replies:
            if marker in last_text:
                return canned
        return subtask_reply

    provider = _StubFrameProvider()
    module = PlanSubtasksMemoryModule(
        vlm=StubVlmClient(responder=responder),
        # Disable the rephrasings sub-prompt so the test's only video-bearing
        # call is the subtask one — keeps the assertions below focused on
        # ``_generate_subtasks`` rather than fighting the order of unrelated
        # text-only Module-1 sub-prompts.
        config=PlanConfig(max_video_frames=5, frames_per_second=10.0, n_task_rephrasings=0),
        frame_provider=provider,
    )
    episode = next(iter_episodes(fixture_dataset_root))
    stage = EpisodeStaging(tmp_path / "stage", episode.episode_index)
    module.run_episode(episode, stage)

    # Find the call carrying the subtask prompt rather than blindly taking
    # the first captured call — Module 1 issues several sub-prompts and their
    # order is not part of the contract.
    assert seen_prompts, "no VLM calls made"

    def _prompt_text(messages):
        # Text of the first text block in the prompt, or "" when there is none.
        first_text = (
            part.get("text", "")
            for message in messages
            for part in message.get("content", [])
            if _is_block(part, "text")
        )
        return next(first_text, "")

    subtask_calls = [prompt for prompt in seen_prompts if "atomic subtasks" in _prompt_text(prompt)]
    assert len(subtask_calls) == 1, "expected exactly one subtask-prompt VLM call"

    content = subtask_calls[0][0]["content"]
    by_kind = {
        kind: [part for part in content if _is_block(part, kind)]
        for kind in ("video", "image", "text")
    }
    assert len(by_kind["video"]) == 1, f"expected exactly 1 video block, got {content}"
    assert by_kind["image"] == [], "subtask prompt must not mix image blocks with the video block"
    assert len(by_kind["text"]) == 1

    # Video block must wrap a list of frames, capped at max_video_frames.
    video_frames = by_kind["video"][0]["video"]
    assert isinstance(video_frames, list)
    assert len(video_frames) <= 5

    # provider is called with target_count = min(duration * fps, max). With
    # fps=10 on a ~1s episode that requests >max, so max=5 wins.
    assert provider.video_calls and provider.video_calls[0][0] == episode.episode_index
    assert provider.video_calls[0][1] <= 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None:
    """Each VQA prompt must carry a single image block at the emission frame."""
    seen_prompts: list[list[dict[str, Any]]] = []
    canned_answer = {
        "question": "How many cups?",
        "answer": {"label": "cup", "count": 1},
    }
    provider = _StubFrameProvider()
    module = GeneralVqaModule(
        vlm=_spy_responder(seen_prompts, canned_answer),
        config=VqaConfig(vqa_emission_hz=1.0, K=1),
        seed=0,
        frame_provider=provider,
    )
    episode = next(iter_episodes(single_episode_root))
    stage = EpisodeStaging(tmp_path / "stage", episode.episode_index)
    module.run_episode(episode, stage)

    assert seen_prompts, "no VLM calls made"

    def _blocks_of(content, kind):
        return [part for part in content if isinstance(part, dict) and part.get("type") == kind]

    for prompt in seen_prompts:
        content = prompt[0]["content"]
        images = _blocks_of(content, "image")
        texts = _blocks_of(content, "text")
        assert len(images) == 1, f"expected 1 image block per VQA prompt, got {content}"
        # The attached image must be exactly what the frame provider returned.
        assert images[0]["image"] is provider.sentinel
        assert len(texts) == 1

    # Each recorded provider call asks for exactly one timestamp, which must be
    # an episode frame timestamp, on one of the provider's cameras.
    for episode_index, requested, camera in provider.calls:
        assert episode_index == episode.episode_index
        assert len(requested) == 1
        assert requested[0] in episode.frame_timestamps
        assert camera in provider.cameras
|
||||||
|
|
||||||
|
|
||||||
|
def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
    """Assistant VQA rows store the answer as a JSON string that round-trips.

    The canned VLM reply carries a ``detections`` answer; every staged
    assistant VQA row must decode with ``json.loads`` and expose that key.
    """
    payload = {
        "question": "Where is the cup?",
        "answer": {"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]},
    }
    vlm = make_canned_responder({"frame-grounded visual question": payload})
    module = GeneralVqaModule(
        vlm=vlm,
        config=VqaConfig(vqa_emission_hz=1.0, K=2),
        seed=2,
        frame_provider=_StubFrameProvider(),
    )
    record = next(iter_episodes(single_episode_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    module.run_episode(record, staging)
    rows = staging.read("vqa")

    assistant_rows = [row for row in rows if row["role"] == "assistant" and row["style"] == "vqa"]
    # Guard against a vacuous pass: with no assistant rows the loop below
    # would check nothing and the test would succeed without decoding anything.
    assert assistant_rows, "expected at least one assistant VQA row to validate"
    for row in assistant_rows:
        decoded = json.loads(row["content"])
        assert "detections" in decoded
|
||||||
175
tests/annotations/test_pipeline_recipe_render.py
Normal file
175
tests/annotations/test_pipeline_recipe_render.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""End-to-end smoke: pipeline output → PR 1 canonical recipe rendering."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
|
||||||
|
from lerobot.annotations.steerable_pipeline.config import (
|
||||||
|
AnnotationPipelineConfig,
|
||||||
|
InterjectionsConfig,
|
||||||
|
PlanConfig,
|
||||||
|
VqaConfig,
|
||||||
|
)
|
||||||
|
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||||
|
from lerobot.annotations.steerable_pipeline.modules import (
|
||||||
|
GeneralVqaModule,
|
||||||
|
InterjectionsAndSpeechModule,
|
||||||
|
PlanSubtasksMemoryModule,
|
||||||
|
)
|
||||||
|
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||||
|
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||||
|
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
|
||||||
|
from lerobot.datasets.language_render import render_sample
|
||||||
|
|
||||||
|
from ._helpers import make_canned_responder
|
||||||
|
|
||||||
|
|
||||||
|
def _build_pr1_style_blend_recipe() -> TrainingRecipe:
    """Inline blend recipe that consumes every style this pipeline produces.

    PR 1 used to ship ``src/lerobot/configs/recipes/pi05_hirobot.yaml`` as
    a canonical example, but that file was dropped during PR 1 review. The
    cross-PR contract this test guards is "the recipe DSL can render
    non-empty messages from pipeline output", which doesn't require a
    specific YAML — so we build the equivalent blend in code.
    """
    # Component 1: task + plan + memory in, current subtask out.
    low_level_execution = TrainingRecipe(
        weight=0.35,
        messages=[
            MessageTurn(
                role="user",
                content="${task}\nPlan: ${plan}\nMemory: ${memory}",
                stream="high_level",
            ),
            MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
        ],
    )

    # Component 2: optional interjection in, plan (plus spoken reply) out.
    interjection_turn = MessageTurn(
        role="user",
        content="${interjection}",
        stream="high_level",
        if_present="interjection",
    )
    plan_reply_turn = MessageTurn(
        role="assistant",
        content="${plan}",
        stream="high_level",
        target=True,
        if_present="plan",
        tool_calls_from="speech",
    )
    user_interjection_response = TrainingRecipe(
        weight=0.16,
        bindings={
            "speech": "emitted_at(t, role=assistant, tool_name=say)",
            "interjection": "emitted_at(t, style=interjection)",
        },
        messages=[
            MessageTurn(role="user", content="${task}", stream="high_level"),
            interjection_turn,
            plan_reply_turn,
        ],
    )

    return TrainingRecipe(
        blend={
            "low_level_execution": low_level_execution,
            "user_interjection_response": user_interjection_response,
        }
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _build_executor() -> Executor:
    """Assemble an ``Executor`` whose three modules share one canned VLM stub.

    The stub maps prompt markers to fixed replies, so the whole pipeline
    (plan, interjections, VQA, writer, validator) runs without a real model.
    """
    interjection_reply = {
        "interjection": "use less water",
        "speech": "Using less water.",
    }
    vlm = make_canned_responder(
        {
            "atomic subtasks": {
                "subtasks": [
                    {"text": "grasp the bottle", "start": 0.0, "end": 0.5},
                    {"text": "pour into the cup", "start": 0.5, "end": 1.0},
                    {"text": "place the bottle down", "start": 1.0, "end": 1.5},
                ]
            },
            "concise hierarchical PLAN": {"plan": "1. grasp\n2. pour\n3. place"},
            "Update the memory": {"memory": "poured once"},
            "acknowledgement the robot": {"text": "Sure."},
            # The Module 2 unit tests key the interjection prompt on
            # "Write ONE interjection" and describe "ONE realistic
            # interruption" as the marker of an older prompt version. Register
            # the reply under both so it matches whichever wording is in use.
            "Write ONE interjection": interjection_reply,
            "ONE realistic interruption": interjection_reply,
            "frame-grounded visual question": {
                "question": "How many cups?",
                "answer": {"label": "cup", "count": 1},
            },
        },
    )
    config = AnnotationPipelineConfig(
        plan=PlanConfig(),
        interjections=InterjectionsConfig(max_interjections_per_episode=1, interjection_min_t=0.5),
        vqa=VqaConfig(vqa_emission_hz=1.0, K=2),
    )
    return Executor(
        config=config,
        plan=PlanSubtasksMemoryModule(vlm=vlm, config=config.plan),
        interjections=InterjectionsAndSpeechModule(vlm=vlm, config=config.interjections, seed=config.seed),
        vqa=GeneralVqaModule(vlm=vlm, config=config.vqa, seed=config.seed),
        writer=LanguageColumnsWriter(),
        validator=StagingValidator(),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_pr1_canonical_recipe_renders_nonempty_from_pipeline_output(
    single_episode_root: Path,
) -> None:
    """Run the stubbed pipeline end to end and render its output with the blend recipe."""
    summary = _build_executor().run(single_episode_root)
    # Validator may emit warnings but no errors for the synthetic fixture.
    report = summary.validation_report
    assert report.ok, report.summary()

    parquet_path = single_episode_root / "data" / "chunk-000" / "file-000.parquet"
    table = pq.read_table(parquet_path)
    persistent_lists = table.column("language_persistent").to_pylist()
    events_lists = table.column("language_events").to_pylist()
    timestamps = table.column("timestamp").to_pylist()

    recipe = _build_pr1_style_blend_recipe()

    # Stop at the first frame that renders a non-empty conversation.
    rendered_any = False
    for frame_ts, persistent, events in zip(timestamps, persistent_lists, events_lists, strict=True):
        rendered = render_sample(
            recipe=recipe,
            persistent=persistent,
            events=events,
            t=float(frame_ts),
            sample_idx=0,
            dataset_ctx={"task": "Pour water from the bottle into the cup."},
        )
        if rendered is None or not rendered["messages"]:
            continue
        rendered_any = True
        assert rendered["target_message_indices"]
        break
    assert rendered_any, "PR 1 recipe rendered no messages from pipeline output"

    # Sanity: speech atom appears in events column intact.
    speech_rows = [
        row
        for frame_events in events_lists
        for row in frame_events
        if row.get("style") is None and row.get("role") == "assistant"
    ]
    assert speech_rows
    say_function = speech_rows[0]["tool_calls"][0]["function"]
    assert say_function["name"] == "say"
    assert isinstance(say_function["arguments"]["text"], str)

    # PR 2 no longer writes a ``tools`` column — the say schema lives as a
    # constant (``SAY_TOOL_SCHEMA``) so PR 1's row struct is the single
    # source of truth for the v3.1 schema.
    assert "tools" not in table.column_names
|
||||||
125
tests/annotations/test_validator.py
Normal file
125
tests/annotations/test_validator.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Validator behavior tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||||
|
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||||
|
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||||
|
from lerobot.annotations.steerable_pipeline.writer import speech_atom
|
||||||
|
|
||||||
|
|
||||||
|
def _validate(root: Path, staging_dir: Path):
    """Validate ``staging_dir`` against every episode found under ``root``.

    Returns whatever ``StagingValidator.validate`` produces (the report the
    tests inspect through ``ok`` and ``errors``).
    """
    episodes = [*iter_episodes(root)]
    validator = StagingValidator()
    return validator.validate(episodes, staging_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validator_catches_misaligned_timestamps(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """A staged row whose timestamp is off the frame grid must fail validation."""
    staging_dir = tmp_path / "stage"
    off_grid_row = {
        "role": "assistant",
        "content": json.dumps({"label": "cup", "count": 2}, sort_keys=True),
        "style": "vqa",
        "timestamp": 9.999,  # not on any 10 fps frame
        "tool_calls": None,
    }
    EpisodeStaging(staging_dir, 0).write("vqa", [off_grid_row])

    report = _validate(fixture_dataset_root, staging_dir)

    assert not report.ok
    expected_fragment = "does not match any source frame timestamp"
    assert any(expected_fragment in error for error in report.errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validator_catches_orphan_speech(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """An interjection with no speech at the same timestamp must fail validation."""
    staging_dir = tmp_path / "stage"
    # Interjection at 0.3s with NO paired speech.
    unpaired_interjection = {
        "role": "user",
        "content": "skip it",
        "style": "interjection",
        "timestamp": 0.3,
        "tool_calls": None,
    }
    staged_rows = [speech_atom(0.0, "Got it."), unpaired_interjection]
    EpisodeStaging(staging_dir, 0).write("interjections", staged_rows)

    report = _validate(fixture_dataset_root, staging_dir)

    assert not report.ok
    assert any("paired speech" in error for error in report.errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """An interjection without a plan row at the same timestamp must fail validation."""
    staging_dir = tmp_path / "stage"

    def _assistant_row(style, content, timestamp):
        return {
            "role": "assistant",
            "content": content,
            "style": style,
            "timestamp": timestamp,
            "tool_calls": None,
        }

    # Plan and subtask exist only at t=0.0 ...
    plan_rows = [
        _assistant_row("plan", "1. do x", 0.0),
        _assistant_row("subtask", "do x", 0.0),
    ]
    EpisodeStaging(staging_dir, 0).write("plan", plan_rows)

    # ... while the interjection (with its paired speech) lands at t=0.4.
    interjection_rows = [
        speech_atom(0.0, "Got it."),
        speech_atom(0.4, "Replanning."),
        {
            "role": "user",
            "content": "replan",
            "style": "interjection",
            "timestamp": 0.4,
            "tool_calls": None,
        },
    ]
    EpisodeStaging(staging_dir, 0).write("interjections", interjection_rows)

    report = _validate(fixture_dataset_root, staging_dir)

    # Missing co-timestamped plan refresh at 0.4s → error.
    assert not report.ok
    assert any("co-timestamped plan update" in error for error in report.errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validator_catches_wrong_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """A ``vqa``-styled row staged under ``plan`` must fail validation."""
    staging_dir = tmp_path / "stage"
    misrouted_row = {"role": "user", "content": "where?", "style": "vqa", "timestamp": 0.0, "tool_calls": None}
    EpisodeStaging(staging_dir, 0).write("plan", [misrouted_row])

    report = _validate(fixture_dataset_root, staging_dir)

    assert not report.ok
    # Either wording of the validator's complaint is accepted.
    accepted_fragments = ("plan emitted style 'vqa'", "must be persistent")
    assert any(fragment in error for error in report.errors for fragment in accepted_fragments)
|
||||||
350
tests/annotations/test_writer.py
Normal file
350
tests/annotations/test_writer.py
Normal file
@@ -0,0 +1,350 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Writer correctness tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||||
|
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||||
|
from lerobot.annotations.steerable_pipeline.writer import (
|
||||||
|
LanguageColumnsWriter,
|
||||||
|
speech_atom,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _stage_episode(
    staging_dir: Path,
    episode_index: int,
    *,
    plan: list[dict] | None = None,
    interjections: list[dict] | None = None,
    vqa: list[dict] | None = None,
) -> None:
    """Write the given rows into one episode's staging area.

    Each keyword names a staging section; a section left as ``None`` is not
    written at all (an empty list is still written).
    """
    staging = EpisodeStaging(staging_dir, episode_index)
    sections = (
        ("plan", plan),
        ("interjections", interjections),
        ("vqa", vqa),
    )
    for section_name, section_rows in sections:
        if section_rows is None:
            continue
        staging.write(section_name, section_rows)
|
||||||
|
|
||||||
|
|
||||||
|
def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """Every frame in an episode has a byte-identical persistent list."""
    staging_dir = tmp_path / "stage"

    def _assistant_row(style, content, timestamp):
        return {
            "role": "assistant",
            "content": content,
            "style": style,
            "timestamp": timestamp,
            "tool_calls": None,
        }

    _stage_episode(
        staging_dir,
        0,
        plan=[
            _assistant_row("subtask", "grasp the sponge", 0.0),
            _assistant_row("plan", "1. wipe\n2. dry", 0.0),
            _assistant_row("memory", "wiped the counter", 0.5),
        ],
    )
    episodes = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(episodes, staging_dir, fixture_dataset_root)

    parquet_path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
    per_frame = pq.read_table(parquet_path).column("language_persistent").to_pylist()
    reference = per_frame[0]
    assert reference  # non-empty
    for frame_rows in per_frame:
        assert frame_rows == reference, "persistent slice must be byte-identical across all frames"
|
||||||
|
|
||||||
|
|
||||||
|
def test_writer_events_exact_timestamp(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """Event rows show up only on frames whose timestamp matches the staged one.

    Frames at t=0.0 must hold the style-less assistant speech row, frames at
    t=0.5 must hold both the interjection and a style-less row, and every
    other frame must have an empty event list.
    """
    staging_dir = tmp_path / "stage"
    user_interjection = {
        "role": "user",
        "content": "skip the dishes",
        "style": "interjection",
        "timestamp": 0.5,
        "tool_calls": None,
    }
    _stage_episode(
        staging_dir,
        0,
        interjections=[
            speech_atom(0.0, "Got it."),
            user_interjection,
            speech_atom(0.5, "Skipping the dishes."),
        ],
    )

    episodes = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(episodes, staging_dir, fixture_dataset_root)

    shard = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
    table = pq.read_table(shard)
    frame_times = table.column("timestamp").to_pylist()
    frame_events = table.column("language_events").to_pylist()

    for ts, rows in zip(frame_times, frame_events, strict=True):
        if abs(ts - 0.0) < 1e-9:
            # Speech atoms are assistant rows without a style.
            assert any(r["role"] == "assistant" and r.get("style") is None for r in rows), rows
            continue
        if abs(ts - 0.5) < 1e-9:
            assert any(r.get("style") == "interjection" for r in rows), rows
            assert any(r.get("style") is None for r in rows), rows
            continue
        # No event was staged for any other timestamp.
        assert rows == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """Staged rows are split between the persistent and the events column by style.

    ``subtask`` / ``plan`` / ``memory`` rows must be the only styles in
    ``language_persistent``; style-less speech rows, ``interjection`` and
    ``vqa`` rows must be the only styles in ``language_events``.
    """

    def _row(role: str, content: str, style: str, timestamp: float) -> dict:
        return {
            "role": role,
            "content": content,
            "style": style,
            "timestamp": timestamp,
            "tool_calls": None,
        }

    staging_dir = tmp_path / "stage"

    plan_rows = [
        _row("assistant", "do X", "subtask", 0.0),
        _row("assistant", "1. do X", "plan", 0.0),
        _row("assistant", "did X", "memory", 0.3),
    ]
    interjection_rows = [
        speech_atom(0.0, "OK"),
        _row("user", "wait", "interjection", 0.2),
        speech_atom(0.2, "Waiting"),
    ]
    detection_answer = json.dumps(
        {"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [1, 2, 3, 4]}]},
        sort_keys=True,
    )
    vqa_rows = [
        {
            "role": "user",
            "content": "where is the cup?",
            "style": "vqa",
            "timestamp": 0.4,
            "camera": "observation.images.front",
            "tool_calls": None,
        },
        {
            "role": "assistant",
            "content": detection_answer,
            "style": "vqa",
            "timestamp": 0.4,
            "camera": "observation.images.front",
            "tool_calls": None,
        },
    ]
    _stage_episode(
        staging_dir,
        0,
        plan=plan_rows,
        interjections=interjection_rows,
        vqa=vqa_rows,
    )

    episodes = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(episodes, staging_dir, fixture_dataset_root)
    table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")

    # The first frame's persistent list is enough to inspect the routed styles.
    first_frame_persistent = table.column("language_persistent").to_pylist()[0]
    assert {r["style"] for r in first_frame_persistent} == {"subtask", "plan", "memory"}

    # Flatten the per-frame event lists and collect every style that occurs.
    seen_event_styles = set()
    for frame_events in table.column("language_events").to_pylist():
        for r in frame_events:
            seen_event_styles.add(r.get("style"))
    assert seen_event_styles == {None, "interjection", "vqa"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """The writer removes ``subtask_index`` and a second run leaves the language columns unchanged."""
    staging_dir = tmp_path / "stage"
    subtask_row = {
        "role": "assistant",
        "content": "do X",
        "style": "subtask",
        "timestamp": 0.0,
        "tool_calls": None,
    }
    _stage_episode(staging_dir, 0, plan=[subtask_row])

    writer = LanguageColumnsWriter()
    writer.write_all(list(iter_episodes(fixture_dataset_root)), staging_dir, fixture_dataset_root)

    shard = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
    first_pass = pq.read_table(shard)
    columns = first_pass.column_names
    assert "subtask_index" not in columns
    assert "language_persistent" in columns
    assert "language_events" in columns
    # No dataset-level ``tools`` column is expected: the ``say`` tool schema
    # is kept as a code constant (``SAY_TOOL_SCHEMA``) instead of being
    # stored in the parquet.
    assert "tools" not in columns

    # Second pass over the already-rewritten shard: the language columns
    # must come out exactly as they did the first time.
    writer.write_all(list(iter_episodes(fixture_dataset_root)), staging_dir, fixture_dataset_root)
    second_pass = pq.read_table(shard)
    for column in ("language_persistent", "language_events"):
        assert first_pass.column(column).to_pylist() == second_pass.column(column).to_pylist()
|
||||||
|
|
||||||
|
|
||||||
|
def test_writer_normalize_rejects_misrouted_persistent_style() -> None:
    """A ``vqa`` row handed to ``_normalize_persistent_row`` raises ``ValueError``."""
    from lerobot.annotations.steerable_pipeline.writer import _normalize_persistent_row

    misrouted_row = {
        "role": "assistant",
        "content": "oops",
        "style": "vqa",
        "timestamp": 0.0,
        "tool_calls": None,
    }
    # The error message is expected to mention the offending style category.
    with pytest.raises(ValueError, match="non-persistent style"):
        _normalize_persistent_row(misrouted_row)
|
||||||
|
|
||||||
|
|
||||||
|
def test_writer_normalize_rejects_misrouted_event_style() -> None:
    """A ``subtask`` row handed to ``_normalize_event_row`` raises ``ValueError``."""
    from lerobot.annotations.steerable_pipeline.writer import _normalize_event_row

    # Note: this row deliberately carries no ``timestamp`` key.
    misrouted_row = {
        "role": "assistant",
        "content": "oops",
        "style": "subtask",
        "tool_calls": None,
    }
    with pytest.raises(ValueError):
        _normalize_event_row(misrouted_row)
|
||||||
|
|
||||||
|
|
||||||
|
def test_say_tool_schema_constant_is_well_formed() -> None:
    """Check the shape of ``SAY_TOOL_SCHEMA`` and that ``DEFAULT_TOOLS`` wraps it.

    These constants stand in for a parquet ``tools`` column; consumers are
    expected to import them from the writer module.
    """
    from lerobot.annotations.steerable_pipeline.writer import (
        DEFAULT_TOOLS,
        SAY_TOOL_SCHEMA,
    )

    assert DEFAULT_TOOLS == [SAY_TOOL_SCHEMA]

    function_spec = SAY_TOOL_SCHEMA["function"]
    assert function_spec["name"] == "say"

    # The tool takes a single required string argument named ``text``.
    parameters = function_spec["parameters"]
    assert parameters["properties"]["text"]["type"] == "string"
    assert parameters["required"] == ["text"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_writer_does_not_add_tools_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """After a write pass the rewritten shard has no ``tools`` column.

    NOTE(review): this test only runs the writer on the fixture shard as-is;
    it does not first inject a legacy ``tools`` column, so the "drop an
    existing legacy column on rerun" scenario is not exercised here — confirm
    whether a dedicated test for that path is wanted.
    """
    staging_dir = tmp_path / "stage"
    subtask_row = {
        "role": "assistant",
        "content": "x",
        "style": "subtask",
        "timestamp": 0.0,
        "tool_calls": None,
    }
    _stage_episode(staging_dir, 0, plan=[subtask_row])

    episodes = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(episodes, staging_dir, fixture_dataset_root)

    shard = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
    written_columns = pq.read_table(shard).column_names
    assert "tools" not in written_columns
|
||||||
|
|
||||||
|
|
||||||
|
def test_annotation_metadata_sync_allows_non_streaming_load(
    fixture_dataset_root: Path, tmp_path: Path
) -> None:
    """The language columns must be declared in ``meta/info.json`` after syncing.

    The test resets ``info.json`` to the four scalar features only, runs the
    writer, then calls ``Executor._ensure_annotation_metadata_in_info`` and
    checks that (a) every entry of ``language_feature_info()`` is present in
    the synced features and (b) ``load_nested_dataset`` succeeds with HF
    features derived from that metadata and exposes both language columns.
    """
    from lerobot.annotations.steerable_pipeline.executor import Executor
    from lerobot.datasets.feature_utils import get_hf_features_from_features
    from lerobot.datasets.io_utils import load_info, load_nested_dataset
    from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT, language_feature_info

    # Start from metadata that knows nothing about the language columns.
    info_path = fixture_dataset_root / "meta" / "info.json"
    info = json.loads(info_path.read_text())
    scalar_dtypes = {
        "episode_index": "int64",
        "frame_index": "int64",
        "timestamp": "float32",
        "task_index": "int64",
    }
    info["features"] = {
        name: {"dtype": dtype, "shape": (1,), "names": None} for name, dtype in scalar_dtypes.items()
    }
    info_path.write_text(json.dumps(info, indent=2))

    staging_dir = tmp_path / "stage"
    subtask_row = {
        "role": "assistant",
        "content": "do X",
        "style": "subtask",
        "timestamp": 0.0,
        "tool_calls": None,
    }
    _stage_episode(staging_dir, 0, plan=[subtask_row])
    episodes = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(episodes, staging_dir, fixture_dataset_root)

    Executor._ensure_annotation_metadata_in_info(fixture_dataset_root)

    synced_features = load_info(fixture_dataset_root)["features"]
    for key, expected_feature in language_feature_info().items():
        assert synced_features[key] == expected_feature

    # Loading with metadata-derived features must now succeed.
    hf_features = get_hf_features_from_features(synced_features)
    dataset = load_nested_dataset(fixture_dataset_root / "data", features=hf_features)

    loaded_columns = dataset.column_names
    assert LANGUAGE_PERSISTENT in loaded_columns
    assert LANGUAGE_EVENTS in loaded_columns
    # Total frame count of the fixture dataset (presumably summed over all shards — defined by the fixture).
    assert len(dataset) == 24
|
||||||
|
|
||||||
|
|
||||||
|
def test_speech_atom_shape_matches_plan_spec() -> None:
    """``speech_atom`` yields a style-less assistant row with one ``say`` tool call."""
    spoken_text = "I'm cleaning up!"
    atom = speech_atom(2.5, spoken_text)

    # Row-level fields: no style, no content, the timestamp passed in.
    assert atom["role"] == "assistant"
    assert atom["style"] is None
    assert atom["content"] is None
    assert atom["timestamp"] == 2.5

    # The spoken text travels inside the first tool call's arguments.
    tool_calls = atom["tool_calls"]
    assert isinstance(tool_calls, list)
    first_call = tool_calls[0]
    assert first_call["type"] == "function"
    say_function = first_call["function"]
    assert say_function["name"] == "say"
    assert say_function["arguments"]["text"] == spoken_text
|
||||||
61
tests/fixtures/dataset_factories.py
vendored
61
tests/fixtures/dataset_factories.py
vendored
@@ -552,3 +552,64 @@ def lerobot_dataset_factory(
|
|||||||
@pytest.fixture(scope="session")
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
    """Session-scoped factory: ``LeRobotDataset.create`` pre-bound to the dummy repo id and default fps."""
    create_with_defaults = partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
    return create_with_defaults
|
||||||
|
|
||||||
|
|
||||||
|
def build_annotation_dataset(
    root: Path,
    episode_specs: list[tuple[int, int, str]],
    *,
    fps: int = 10,
) -> Path:
    """Write a minimal LeRobot-shaped dataset under ``root`` for annotation tests.

    Args:
        root: Directory that receives ``data/`` and ``meta/``.
        episode_specs: ``(episode_index, num_frames, task_text)`` triples. Each
            episode goes to its own ``data/chunk-000/file-{ep:03d}.parquet``.
        fps: Frame rate; frame ``i`` gets timestamp ``round(i / fps, 6)``.

    Returns:
        ``root``, unchanged, for call chaining.

    Besides the per-episode parquet shards, the tasks table is written via
    ``write_tasks`` and a small ``meta/info.json`` via ``write_json``. No
    video files are created. Every shard includes a constant
    ``subtask_index`` column, a legacy column the writer is expected to drop.
    """
    from lerobot.datasets.io_utils import write_tasks
    from lerobot.utils.io_utils import write_json

    shard_dir = root / "data" / "chunk-000"
    shard_dir.mkdir(parents=True, exist_ok=True)

    # Task text -> task index, assigned in order of first appearance.
    task_to_index: dict[str, int] = {}
    for episode_index, num_frames, task_text in episode_specs:
        task_index = task_to_index.setdefault(task_text, len(task_to_index))
        columns = {
            "episode_index": [episode_index] * num_frames,
            "frame_index": list(range(num_frames)),
            "timestamp": [round(i / fps, 6) for i in range(num_frames)],
            "task_index": [task_index] * num_frames,
            "subtask_index": [0] * num_frames,  # legacy column the writer must drop
        }
        shard_path = shard_dir / f"file-{episode_index:03d}.parquet"
        pd.DataFrame(columns).to_parquet(shard_path, index=False)

    # Tasks frame: indexed by the task string (index name ``task``) with a
    # single ``task_index`` column.
    tasks_df = pd.DataFrame(
        {"task_index": list(task_to_index.values())},
        index=pd.Index(list(task_to_index.keys()), name="task"),
    )
    write_tasks(tasks_df, root)

    info = {
        "codebase_version": "v3.1",
        "fps": fps,
        "features": {},
        "total_episodes": len(episode_specs),
    }
    write_json(info, root / "meta" / "info.json")
    return root
|
||||||
|
|||||||
51
tests/scripts/test_lerobot_annotate.py
Normal file
51
tests/scripts/test_lerobot_annotate.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import json
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
|
||||||
|
def test_push_to_hub_tags_uploaded_dataset_revision(tmp_path, monkeypatch):
    """``_push_to_hub`` creates the repo, uploads, and tags the uploaded commit.

    ``huggingface_hub.HfApi`` is replaced by a recorder. The tag is expected to
    equal the ``codebase_version`` in ``meta/info.json`` and to point at the
    ``oid`` returned by ``upload_folder``.
    """
    from lerobot.scripts.lerobot_annotate import _push_to_hub

    root = tmp_path / "dataset"
    meta_dir = root / "meta"
    meta_dir.mkdir(parents=True)
    (meta_dir / "info.json").write_text(json.dumps({"codebase_version": "v3.0"}))

    # Keyword arguments of each API call, keyed by method name.
    calls = {}

    class FakeHfApi:
        def create_repo(self, **kwargs):
            calls["create_repo"] = kwargs

        def upload_folder(self, **kwargs):
            calls["upload_folder"] = kwargs
            # Stand-in for the commit info object; only ``oid`` is provided.
            return SimpleNamespace(oid="abc123")

        def create_tag(self, **kwargs):
            calls["create_tag"] = kwargs

    monkeypatch.setattr("huggingface_hub.HfApi", FakeHfApi)

    cfg = SimpleNamespace(
        repo_id="source/dataset",
        dest_repo_id="annotated/dataset",
        push_private=True,
        push_commit_message=None,
    )
    _push_to_hub(root, cfg)

    expected_create_repo = {
        "repo_id": "annotated/dataset",
        "repo_type": "dataset",
        "private": True,
        "exist_ok": True,
    }
    expected_create_tag = {
        "repo_id": "annotated/dataset",
        "tag": "v3.0",
        "repo_type": "dataset",
        "exist_ok": True,
        "revision": "abc123",
    }
    assert calls["create_repo"] == expected_create_repo
    assert calls["upload_folder"]["repo_id"] == "annotated/dataset"
    assert calls["create_tag"] == expected_create_tag
|
||||||
Reference in New Issue
Block a user